import numpy as np
import time_dependent as td
import scipy.optimize as opt
from matplotlib.lines import Line2D
import matplotlib.pyplot as plt
import menu_cost as mc
import multi_product as mp
import plot_specs as specs
import general_equilibrium as ge
import general_equilibrium_full as ge_full
import jacobian as jac
import xlsxwriter
import statsmodels.api as sm

# %% Setup

colors_td_section = ['black', 'firebrick']  # colors for td models from section 2
# Line colors for the menu-cost / time-dependent equivalence figures, keyed by curve type.
colors_equivalence = {
    'intensive': 'purple',
    'extensive': 'tab:blue',
    'total': 'black',
    'actual_hazards': 'green',
    'menu_cost': 'mediumblue',
    'td': 'firebrick'}
colors_Es = ['silver', 'darkgrey', 'dimgrey', 'black']  # color for plotting the Es
# One color per entry of share_free_values in the Part 3 duration figures.
colors_duration = ['lightskyblue', 'cornflowerblue', 'royalblue', 'blue', 'darkblue']

add_titles = False
folder = 'figures/'  # output directory passed to specs.save_figure

T = 500  # horizon used to build/truncate Jacobians and shock paths (np.eye(T), np.arange(T))
T_extend = 3000  # passed as T_extend to mc.compute_PC; presumably a longer horizon before cutting back to T — confirm
h = 1e-4  # passed as h to the menu_cost routines; presumably a finite-difference step — confirm
nx = 500  # grid size passed to the menu_cost routines (steady-state grids are split at nx // 2 below)
n_periods = 4  # model periods per year (4 = quarterly); used to convert monthly and quarterly inputs

x_grid_large = np.linspace(-0.75, 0.75, 1000)

# Aggregate data moments. A monthly adjustment probability p is converted to a
# per-period probability as 1 - (1 - p) ** (months per period).
freq_data_monthly = 0.087
freq_data = 1 - (1 - freq_data_monthly) ** (12 / n_periods)
med_abs_dp_data = 0.085

# Sector-level data (14 sectors): monthly adjustment frequencies, mean absolute
# price changes and expenditure weights (the weights sum to 1).
freq_sectors_monthly = np.array([0.916, 0.494, 0.437, 0.254, 0.213, 0.217, 0.119, 0.084, 0.065, 0.062, 0.061, 0.049, 0.036, 0.029])
freq_sectors = 1 - (1 - freq_sectors_monthly) ** (12 / n_periods)
mean_abs_dp_sectors = np.array([0.049, 0.064, 0.184, 0.159, 0.089, 0.040, 0.114, 0.067, 0.101, 0.088, 0.102, 0.081, 0.124, 0.135])
weights_sectors = np.array([0.077, 0.053, 0.055, 0.059, 0.083, 0.077, 0.137, 0.075, 0.05, 0.078, 0.036, 0.076, 0.065, 0.079])

# Per-period discount factor, converted from a quarterly value.
beta_quarterly = 0.99
beta = beta_quarterly ** (4 / n_periods)

# Baseline menu_cost parameters. Every experiment below copies this dict and
# overrides entries; the meaning of each key is defined in menu_cost (not
# visible here). 'r' is the per-period rate implied by beta.
parameters_base = {
    'mu': 0,
    'sig_k': 0.05,
    'tau': 1,
    'r': 1 / beta - 1,
    'distribution': 'normal',
    'sig': 0.05,
    'elast': None,
    'gap': 0,
    'discount': 0,
    'shift': 0,
}

mc_autocorr = 0.8

# Persistences p of the geometric shock paths p ** t fed through the nominal and
# real Jacobians in figure 8.
shock_persistence = {'nominal': [0.3, 0.6, 1],
                     'real': [0.3, 0.6, 0.8]}


def plot_jacobian(J, ax, cols=(0, 10, 20), cut=None, color=None, linestyle='-', alpha=None, alpha_list=None):
    """Plot selected columns of a Jacobian matrix on ``ax``, one line per column.

    Parameters
    ----------
    J : 2-D array
        Matrix whose columns are plotted; each line is ``J[:cut, c]``.
    ax : matplotlib Axes
        Axes to draw on.
    cols : iterable of int, default (0, 10, 20)
        Column indices to plot, in drawing order. The default is a tuple so no
        mutable list is shared between calls.
    cut : int, optional
        Number of leading rows to plot. Defaults to ``max(cols) + 21``.
    color, linestyle :
        Passed unchanged to ``ax.plot`` for every line.
    alpha : float, optional
        Opacity applied to every line when ``alpha_list`` is not given.
    alpha_list : sequence of float, optional
        Per-column opacity, matched to ``cols`` by position.

    Raises
    ------
    ValueError
        If ``alpha_list`` has fewer entries than ``cols``. Checking up front
        avoids leaving a half-drawn figure behind an IndexError.
    """
    cols = list(cols)
    if cut is None:
        # Show rows up to 20 periods past the largest plotted column index.
        cut = max(cols) + 21
    if alpha_list is None:
        alpha_list = len(cols) * [alpha]
    elif len(alpha_list) < len(cols):
        raise ValueError(f'alpha_list has {len(alpha_list)} entries but {len(cols)} columns are plotted')
    for c, line_alpha in zip(cols, alpha_list):
        ax.plot(J[:cut, c], color=color, linestyle=linestyle, alpha=line_alpha)


def ar_irf(ar_roots, shock_sd, T, return_coefs=False):
    """Impulse response of an AR process given the roots of its lag polynomial.

    The lag polynomial prod_k (1 - r_k L) is materialised as a T x T
    lower-triangular matrix. Solving it against a unit impulse at t = 0 gives
    the response, which is then scaled by ``shock_sd``.

    Parameters
    ----------
    ar_roots : sequence of float
        Roots r_k of the lag polynomial.
    shock_sd : float
        Scale applied to the unit-impulse response.
    T : int
        Length of the response.
    return_coefs : bool
        If True, also return the AR coefficients (rho_1, ..., rho_p).

    Returns
    -------
    ndarray, or (ndarray, ndarray) when ``return_coefs`` is True.
    """
    lag_poly = np.eye(T)
    for root in ar_roots:
        # Left-multiply by (1 - root * L): every row minus root times the row above it.
        lag_poly[1:, :] -= root * lag_poly[0:-1, :]
    # Below the diagonal, the first column holds minus the AR coefficients.
    coefs = -lag_poly[1:len(ar_roots) + 1, 0]
    impulse = np.zeros(T)
    impulse[0] = 1
    response = shock_sd * np.linalg.solve(lag_poly, impulse)
    return (response, coefs) if return_coefs else response


def compute_ge_irfs(J, rho_i, T, model='nk'):
    """Compute general-equilibrium impulse responses to an AR(1) ``eps_i`` shock.

    Parameters
    ----------
    J :
        Phillips-curve Jacobians handed to ``ge.phillips_curve``. Callers in
        this file pass ``{'pi': {'gap': J_real}}``.
    rho_i : float
        Persistence of the shock; the shock path is ``0.25 * rho_i ** t``.
    T : int
        Truncation horizon.
    model : {'nk', 'sw'}
        'nk' uses the ``ge.*_nk`` blocks; 'sw' uses the ``ge.*_sw`` blocks.

    Returns
    -------
    dict
        For every key ``k`` of the GE Jacobian ``G``, ``G[k]['eps_i'] @ shock``.

    Raises
    ------
    ValueError
        If ``model`` is neither 'nk' nor 'sw'. Previously any other string
        (e.g. a typo such as 'SW') silently fell back to the NK model.
    """
    if model not in ('nk', 'sw'):
        raise ValueError(f"model must be 'nk' or 'sw', got {model!r}")
    if model == 'sw':
        blocks, unknowns, targets, ss = ge.blocks_sw, ge.unknowns_sw, ge.targets_sw, ge.ss_sw
    else:
        blocks, unknowns, targets, ss = ge.blocks_nk, ge.unknowns_nk, ge.targets_nk, ge.ss_nk
    block_list = blocks + [ge.phillips_curve(J)]
    G = jac.get_G(block_list=block_list, exogenous=['eps_i'], unknowns=unknowns, targets=targets, T=T, ss=ss)
    shock = 0.25 * (rho_i ** np.arange(T))
    return {k: G[k]['eps_i'] @ shock for k in G.keys()}


def compute_ge_full_irfs(J, parameters, ss, rho_i, T):
    """GE impulse responses to an AR(1) ``eps_i`` shock in the ``ge_full`` model.

    The steady state is built by ``ge_full.compute_ss(parameters, ss)``. The
    block list is ``ge_full.model``, ``ge_full.pc_inputs`` and the Phillips
    curve constructed from ``J``. The shock path is ``0.25 * rho_i ** t`` for
    t = 0..T-1.

    Returns
    -------
    dict
        For every key ``k`` of the GE Jacobian ``G``, ``G[k]['eps_i'] @ shock``.
    """
    steady_state = ge_full.compute_ss(parameters, ss)
    blocks = [ge_full.model, ge_full.pc_inputs, ge_full.phillips_curve(J)]
    G = jac.get_G(block_list=blocks, exogenous=['eps_i'], unknowns=ge_full.unknowns,
                  targets=ge_full.targets, T=T, ss=steady_state)
    horizon = np.arange(T)
    eps_path = 0.25 * (rho_i ** horizon)
    return {name: G[name]['eps_i'] @ eps_path for name in G.keys()}


def write_table(worksheet, entries, col, row_start=0):
    """Write ``entries`` top-to-bottom into column ``col`` of ``worksheet``.

    The first entry lands in row ``row_start``, the next one row below, and so
    on. Each cell is written with ``worksheet.write(row, col, value)``.
    """
    for row, value in enumerate(entries, start=row_start):
        worksheet.write(row, col, value)


# %% Part 1: figures 1, 2, 3, B1

freq = [0.2, 0.3]
freq_taylor = [1/5, 1/4]  # must be 1 / integer


def freq_from_hazards(hazards):
    """Return the adjustment frequency implied by a hazard profile.

    The frequency is the sum of the hazards weighted by the distribution from
    ``td.hazards_to_price_age_distribution``. Hazards are capped at 1, and the
    same capped profile is used both for that distribution and for the weighted
    sum. Previously the sum used the uncapped hazards, which overstates the
    frequency whenever a hazard exceeds 1 (e.g. for trial scales > 1 in fsolve).
    """
    capped = np.minimum(hazards, 1)
    f = td.hazards_to_price_age_distribution(capped)
    return np.sum(f * capped)


J_td = {}

# Calvo: td.calvo_jacobian takes theta = 1 - freq (the probability of not adjusting).
J_td['calvo'] = {}
J_td['calvo']['nominal'] = [td.calvo_jacobian(1 - f, beta, T) for f in freq]

# Increasing hazards: base profile 1 - exp(-0.2 a) for ages a = 1..T, rescaled
# for each target frequency f so that freq_from_hazards(scale * hazards) == f.
J_td['increasing_hazards'] = {}
hazards = 1 - np.exp(-0.2 * np.arange(1, T + 1))
scale = [opt.fsolve(lambda s: freq_from_hazards(s * hazards) - f, x0=1) for f in freq]  # scales hazards to match the right adjustment frequency
J_td['increasing_hazards']['nominal'] = [td.general_td_jacobian(f=td.hazards_to_price_age_distribution(s * hazards), beta=beta) for s in scale]

# Taylor pricing; the frequencies in freq_taylor must be 1 / integer.
J_td['taylor'] = {}
J_td['taylor']['nominal'] = [td.taylor_jacobian(f, beta, T) for f in freq_taylor]

# Derive the 'real' Jacobians for every model in one place. The calvo and
# increasing-hazards conversions used to be computed once earlier as well,
# then immediately recomputed here.
for model in J_td.keys():
    J_td[model]['real'] = [td.nominal_to_real(J) for J in J_td[model]['nominal']]

alpha_list = [1, 0.7, 0.55]  # opacity of each plotted column, in the order of cols
cols = [0, 10, 20]  # Jacobian columns s plotted in the Part 1 figures

# One legend entry per plotted line, in drawing order. Each figure plots all
# columns of the first frequency's Jacobian, then all columns of the second.
# Deriving the entries from freq/cols keeps the legends in sync if either changes.
legend_calvo = [f'Freq. = {f}, s = {c}' for f in freq for c in cols]
legend_taylor = [f'Freq. = {f}, s = {c}' for f in freq_taylor for c in cols]

# Figure 1a: Calvo price-level Jacobian columns s in ``cols`` for both
# frequencies (black/solid = freq[0], firebrick/dashed = freq[1]); line opacity
# follows ``alpha_list``.
fig, ax = plt.subplots(figsize=specs.figsize_standard)
plot_jacobian(J_td['calvo']['nominal'][0], ax, color=colors_td_section[0], alpha_list=alpha_list, cols=cols)
plot_jacobian(J_td['calvo']['nominal'][1], ax, color=colors_td_section[1], linestyle='--', alpha_list=alpha_list, cols=cols)
ax.set_xlim([0, 40])
ax.set_ylim([-0.01, 0.3])
ax.set_xlabel('Quarters')
ax.set_ylabel('Price level')
ax.legend(legend_calvo, ncol=2, columnspacing=0.5, handlelength=1.5, fontsize=16)
specs.save_figure(fig, 'figure_1_a', folder=folder)

# Figure 1b: same layout for the increasing-hazards model (no legend here).
fig, ax = plt.subplots(figsize=specs.figsize_standard)
plot_jacobian(J_td['increasing_hazards']['nominal'][0], ax, color=colors_td_section[0], alpha_list=alpha_list, cols=cols)
plot_jacobian(J_td['increasing_hazards']['nominal'][1], ax, color=colors_td_section[1], linestyle='--', alpha_list=alpha_list, cols=cols)
ax.set_xlim([0, 40])
ax.set_ylim([-0.01, 0.3])
ax.set_xlabel('Quarters')
ax.set_ylabel('Price level')
specs.save_figure(fig, 'figure_1_b', folder=folder)

# Figure 2a: response to a shock at date 0 only (J @ e_0, i.e. column 0).
# Calvo vs increasing hazards, both at the second frequency freq[1].
fig, ax = plt.subplots(figsize=specs.figsize_standard)
ax.plot(J_td['calvo']['nominal'][1] @ np.eye(T, 1), color=colors_td_section[0])
ax.plot(J_td['increasing_hazards']['nominal'][1] @ np.eye(T, 1), color=colors_td_section[1], linestyle='--')
ax.set_xlim([0, 10])
ax.set_ylim([0, 0.1])
ax.set_xlabel('Quarters')
ax.set_ylabel('Price level')
specs.save_figure(fig, 'figure_2_a', folder=folder)

# Figure 2b: response to a shock of one in every period (J @ 1, i.e. row sums).
fig, ax = plt.subplots(figsize=specs.figsize_standard)
ax.plot(J_td['calvo']['nominal'][1] @ np.ones(T), color=colors_td_section[0])
ax.plot(J_td['increasing_hazards']['nominal'][1] @ np.ones(T), color=colors_td_section[1], linestyle='--')
ax.set_xlim([0, 10])
ax.set_ylim([0, 1.05])
ax.legend(['Calvo', 'Increasing hazards'], loc=4)
ax.set_xlabel('Quarters')
ax.set_ylabel('Price level')
specs.save_figure(fig, 'figure_2_b', folder=folder)

# Figure 3: as Figure 1 but for the real (inflation) Jacobians.
# Panel a: Calvo; panel b: increasing hazards.
fig, ax = plt.subplots(figsize=specs.figsize_standard)
plot_jacobian(J_td['calvo']['real'][0], ax, color=colors_td_section[0], alpha_list=alpha_list, cols=cols)
plot_jacobian(J_td['calvo']['real'][1], ax, color=colors_td_section[1], linestyle='--', alpha_list=alpha_list, cols=cols)
ax.set_xlim([0, 40])
ax.set_ylim([-0.01, 0.20])
ax.set_xlabel('Quarters')
ax.set_ylabel('Inflation')
ax.legend(legend_calvo, ncol=2, columnspacing=0.5, handlelength=1.5, fontsize=16)
specs.save_figure(fig, 'figure_3_a', folder=folder)

fig, ax = plt.subplots(figsize=specs.figsize_standard)
plot_jacobian(J_td['increasing_hazards']['real'][0], ax, color=colors_td_section[0], alpha_list=alpha_list, cols=cols)
plot_jacobian(J_td['increasing_hazards']['real'][1], ax, color=colors_td_section[1], linestyle='--', alpha_list=alpha_list, cols=cols)
ax.set_xlim([0, 40])
ax.set_xlabel('Quarters')
ax.set_ylabel('Inflation')
specs.save_figure(fig, 'figure_3_b', folder=folder)

# Figure B1: Taylor model, (a) price-level and (b) inflation Jacobian columns.
fig, ax = plt.subplots(figsize=specs.figsize_standard)
plot_jacobian(J_td['taylor']['nominal'][0], ax, color=colors_td_section[0], alpha_list=alpha_list, cols=cols)
plot_jacobian(J_td['taylor']['nominal'][1], ax, color=colors_td_section[1], linestyle='--', alpha_list=alpha_list, cols=cols)
ax.set_xlim([0, 40])
ax.set_ylim([-0.01, 0.4])
ax.set_xlabel('Quarters')
ax.set_ylabel('Price level')
ax.legend(legend_taylor, ncol=2, columnspacing=0.5, handlelength=1.5, fontsize=16)
specs.save_figure(fig, 'figure_B1_a', folder=folder)

fig, ax = plt.subplots(figsize=specs.figsize_standard)
plot_jacobian(J_td['taylor']['real'][0], ax, color=colors_td_section[0], alpha_list=alpha_list, cols=cols)
plot_jacobian(J_td['taylor']['real'][1], ax, color=colors_td_section[1], linestyle='--', alpha_list=alpha_list, cols=cols)
ax.set_xlim([0, 40])
ax.set_xlabel('Quarters')
ax.set_ylabel('Inflation')
specs.save_figure(fig, 'figure_B1_b', folder=folder)


# %% Part 2: figures 4, 5, 6, 7, 8, 10, 11, 12, table 1

model_names = ['GL', 'NS']

# NOTE(review): the Part 2 loop below uses hard-coded, pre-calibrated mu_k/sig
# values and does not read ``targs`` in the visible code.
targs = {'freq': freq_data, 'med_abs_dp': med_abs_dp_data}  # calibration targets

# Legend handles shared by the Part 2 figures: menu cost (solid) vs Calvo (dashed).
custom_lines = [Line2D([0], [0], color=colors_equivalence['menu_cost'], linestyle='-'),
                Line2D([0], [0], color=colors_equivalence['td'], linestyle='--')]
legend = ['Menu cost', 'Calvo']

rho_i = 0.5  # persistence of monetary shock in standard NK model
rho_i_sw = 0.15  # persistence of monetary shock in Smets-Wouters model

# NK-model GE impulse responses per model (exact and Calvo-approximated), filled in the loop.
irf_nk = {model: {} for model in model_names}
irf_nk_approx = {model: {} for model in model_names}

# Table 1 (calibration summary). Column 0 holds the row labels; each model
# writes its own column inside the loop and a data column is added after it.
# NOTE(review): xlsxwriter writes the file only on close(), which happens after
# the loop, so an exception inside the loop leaves no table on disk.
workbook_calibration = xlsxwriter.Workbook('tables/table_1.xlsx')
worksheet_calibration = workbook_calibration.add_worksheet()
write_table(worksheet_calibration, ['Model', 'Menu cost', 'Shock volatility', 'Prob. of free adjustment', 'Frequency', 'Median adjustment size'], 0, row_start=0)

for model in model_names:

    # Per-model settings: share of free adjustments, Table 1 column, figure panel
    # letters, and pre-calibrated menu-cost parameters (hard-coded, not recalibrated here).
    if model == 'GL':
        share_free = 0.0
        col_table = 1
        fig_letter = 'a'
        fig_letter_2 = 'c'
        calibrated = {'mu_k': -5.122741752515663, 'sig': 0.045874035207593526}
    else:
        share_free = 0.75
        col_table = 2
        fig_letter = 'b'
        fig_letter_2 = 'd'
        calibrated = {'mu_k': -2.9710458486831417, 'sig': 0.06035491762766833}

    # 'la' is reported in Table 1 as the probability of a free adjustment.
    parameters = parameters_base.copy()
    parameters.update({'la': share_free * freq_data})
    parameters.update(calibrated)

    # compute necessary objects
    J_nom, ss, J_ext, J_int, f_ext, f_int, w_ext, w_int = mc.compute_td_equivalence(parameters, nx, T, h=h)
    J_real = mc.compute_PC(J_nom, parameters, ss=ss, h=h, T_extend=T_extend, permanent_shock_to_normalize=True, cut_back=True)
    J_ext_real, J_int_real = mc.compute_PC(J_ext, cut_back=True), mc.compute_PC(J_int, cut_back=True)
    E, E_prime, E_xbar, P, P_zero, P_xbar = mc.compute_Es(parameters, nx, ss, tmax=T)
    # Point on the upper half of the price-gap grid where ss['q'] reaches
    # (1 + la) / 2. +/- xbar is drawn as the band edges in figure 5.
    xbar = np.interp((1 + parameters['la'])/2, ss['q'][nx // 2:], ss['x_grid'][nx // 2:])

    # Aggregate survival function: weighted mix of the extensive and intensive
    # components. Hazards are 1 - S(t+1) / S(t); the denominator is floored at
    # 1e-8 to avoid division by zero once survival vanishes.
    f_agg = w_ext * f_ext + w_int * f_int
    hazards_ext = 1 - f_ext[1:] / np.maximum(f_ext[:-1], 1e-8)
    hazards_int = 1 - f_int[1:] / np.maximum(f_int[:-1], 1e-8)
    hazards_agg = 1 - f_agg[1:] / np.maximum(f_agg[:-1], 1e-8)

    # Table 1 column for this model; 'Menu cost' is exp(mu_k), i.e. mu_k is its log.
    write_table(worksheet_calibration, [model, np.exp(parameters['mu_k']), parameters['sig'], parameters['la'], ss['stats']['freq'], ss['stats']['med_abs_dp']], col_table, row_start=0)

    # Fit a Calvo model to the real (inflation) Jacobian. theta_approx is the
    # Calvo probability of not adjusting (freq = 1 - theta), and kappa is the
    # Phillips-curve slope implied by theta.
    J_real_approx, theta_approx, _, _, dist_approx = mc.approx_jacobian(J_real, beta, Price=False, Nominal=False, Absolute=True)
    kappa = (1 - theta_approx) * (1 - beta * theta_approx) / theta_approx
    J_nom_approx = td.calvo_jacobian(theta_approx, beta, len(J_real_approx))

    print(f'--- {model} model ---')
    print('Approximation results:')
    print(f'freq. = {1 - theta_approx:.3f}')
    print(f'kappa = {kappa:.3f}')
    print()

    # NOTE: this rebinds the module-level J_td from Part 1 to this model's
    # extensive/intensive Jacobians.
    J = {'nominal': J_nom, 'real': J_real}
    J_td = {'ext': {'nominal': J_ext, 'real': J_ext_real}, 'int': {'nominal': J_int, 'real': J_int_real}}
    J_approx = {'nominal': J_nom_approx, 'real': J_real_approx}

    # GE responses to the monetary shock under the menu-cost Phillips curve and
    # under its Calvo approximation: standard NK (rho_i) and Smets-Wouters (rho_i_sw).
    irf_nk[model] = compute_ge_irfs({'pi': {'gap': J_real}}, rho_i, T)
    irf_nk_approx[model] = compute_ge_irfs({'pi': {'gap': J_real_approx}}, rho_i, T)
    irf_sw = compute_ge_irfs({'pi': {'gap': J_real}}, rho_i_sw, T, model='sw')
    irf_sw_approx = compute_ge_irfs({'pi': {'gap': J_real_approx}}, rho_i_sw, T, model='sw')

    # Figure 4: price-level response to a shock of one in every period (J @ 1),
    # total and split into the weighted extensive and intensive margins.
    fig, ax = plt.subplots(figsize=specs.figsize_standard)
    ax.set_xlim([0, 10])
    ax.set_ylim([-0.1, 1.1])
    ax.yaxis.tick_right()
    ax.set_xlabel('Quarters')
    ax.set_ylabel('Price level')
    ax.plot(J_nom @ np.ones(T), color=colors_equivalence['total'], linestyle='-')
    ax.plot(w_ext * J_ext @ np.ones(T), color=colors_equivalence['extensive'], linestyle='--')
    ax.plot(w_int * J_int @ np.ones(T), color=colors_equivalence['intensive'], linestyle='-.')
    # The two interior ticks mark the margin weights, smaller one first, each
    # labelled and coloured like its margin. Labels are read back from ``ax``
    # itself rather than through the implicit pyplot state (plt.gca()).
    ax.set_yticks([0, np.minimum(w_int, w_ext), np.maximum(w_int, w_ext), 1])
    if w_int < w_ext:
        tick_margins = [(w_int, colors_equivalence['intensive']), (w_ext, colors_equivalence['extensive'])]
    else:
        tick_margins = [(w_ext, colors_equivalence['extensive']), (w_int, colors_equivalence['intensive'])]
    ax.set_yticklabels([0] + [f'{w:.2f}' for w, _ in tick_margins] + [f'{1:.2f}'])
    tick_colors = ['k'] + [c for _, c in tick_margins] + ['k']
    for ticklabel, tickcolor in zip(ax.get_yticklabels(), tick_colors):
        ticklabel.set_color(tickcolor)
    if model == 'NS':
        ax.legend(['Total', 'Extensive margin', 'Intensive margin'], loc=4)
    specs.save_figure(fig, 'figure_4_' + fig_letter, folder=folder)

    # Figure 5 (NS model only): (a) steady-state price-gap distribution with the
    # band edges +/- xbar; (b) E^t(x) for t = 0..3 on the interior |x| < xbar.
    if model == 'NS':
        dx = ss['x_grid'][1] - ss['x_grid'][0]
        fig, ax = plt.subplots(figsize=specs.figsize_standard)
        # Dividing by the grid spacing turns ss['g'] into a density
        # (presumably ss['g'] holds mass per grid point — confirm in menu_cost).
        ax.plot(ss['x_grid'], ss['g'] / dx, color='black', linestyle='-')
        ylim = ax.get_ylim()
        ax.plot([-xbar, -xbar], ylim, linestyle='--', color='grey', linewidth=2)
        ax.plot([xbar, xbar], ylim, linestyle='--', color='grey', linewidth=2)
        ax.set_ylim(ylim)
        ax.set_xlabel('Price gap')
        legend_dist = ['Distribution', r'$\underline{x}, \overline{x}$']
        custom_lines_dist = [Line2D([0], [0], color='black', linestyle='-', linewidth=2),
                             Line2D([0], [0], color='grey', linestyle='--', linewidth=2)]
        ax.legend(custom_lines_dist, legend_dist)
        specs.save_figure(fig, 'figure_5_a', folder=folder)

        legend_Es = []
        linestyles = [':', '-.', '--', '-']
        idx_plot = (np.abs(ss['x_grid']) < xbar)
        t_plot = [0, 1, 2, 3]
        fig, ax = plt.subplots(figsize=specs.figsize_standard)
        for k, t in enumerate(t_plot):
            ax.plot(ss['x_grid'][idx_plot], E[t][idx_plot], color=colors_Es[k], linestyle=linestyles[k])
            legend_Es += [f'$E^{t}(x)$']
        ylim = ax.get_ylim()
        ax.plot([-xbar, -xbar], ylim, linestyle='--', color='grey', linewidth=2)
        ax.plot([xbar, xbar], ylim, linestyle='--', color='grey', linewidth=2)
        ax.set_ylim([-0.45 * xbar, 0.45 * xbar])
        ax.set_yticks([-0.06, -0.03, 0, 0.03, 0.06])
        ax.legend(legend_Es)
        specs.save_figure(fig, 'figure_5_b', folder=folder)

    # Figure 6 (a/b): survival functions of the intensive and extensive margins
    # and their weighted average (each normalized by its t = 0 value), plus the
    # model's actual survival from ss['stats'].
    survival_mc = ss['stats']['survival']
    fig, ax = plt.subplots(figsize=specs.figsize_standard)
    ax.plot(f_int[:7] / f_int[0], linestyle='-', color=colors_equivalence['intensive'])
    ax.plot(f_ext[:7] / f_ext[0], linestyle='--', color=colors_equivalence['extensive'])
    ax.set_xlabel('Quarters')
    ax.set_ylabel('Survival function')
    ax.set_xlim([0, 6])
    ax.plot(survival_mc, linestyle='-.', color=colors_equivalence['actual_hazards'])
    ax.plot(f_agg[:7] / f_agg[0], linestyle=':', color='black', linewidth=3.5)
    if model == 'NS':
        # Raw strings: '\P' is an invalid escape sequence in a normal literal
        # (SyntaxWarning on Python 3.12+). The raw form yields identical text.
        ax.legend([r'Intensive margin ($\Phi^i$)', r'Extensive margin ($\Phi^e$)', r'Actual survival ($\Phi^{actual}$)', 'Weighted average'])
    specs.save_figure(fig, 'figure_6_' + fig_letter, folder=folder, close=True)

    # Figure 6 (c/d): adjustment hazards implied by each survival function,
    # plus the model's actual hazards from ss['stats'].
    hazards_mc = ss['stats']['hazards']
    fig, ax = plt.subplots(figsize=specs.figsize_standard)
    ax.plot(hazards_int[:7], linestyle='-', color=colors_equivalence['intensive'])
    ax.plot(hazards_ext[:7], linestyle='--', color=colors_equivalence['extensive'])
    ax.set_xlabel('Quarters')
    ax.set_ylabel('Adjustment hazards')
    ax.set_xlim([0, 6])
    # Set the y-axis floor to the largest multiple of 0.1 not above the smallest
    # actual hazard (over all ages, not just the first 7).
    ylim = list(ax.get_ylim())
    ylim[0] = 0.1 * np.floor(np.min(hazards_mc) / 0.1)
    ax.set_ylim(ylim)
    ax.plot(hazards_mc[:7], linestyle='-.', color=colors_equivalence['actual_hazards'])
    ax.plot(hazards_agg[:7], linestyle=':', color='black', linewidth=3.5)
    specs.save_figure(fig, 'figure_6_' + fig_letter_2, folder=folder)

    # Figure 7: menu-cost Jacobian columns (solid) vs the fitted Calvo
    # approximation (dashed); panels a/b price level, c/d inflation.
    fig, ax = plt.subplots(figsize=specs.figsize_standard)
    plot_jacobian(J['nominal'], ax, color=colors_equivalence['menu_cost'])
    plot_jacobian(J_approx['nominal'], ax, color=colors_equivalence['td'], linestyle='--')
    ax.set_xlim([0, 40])
    ax.set_xlabel('Quarters')
    ax.set_ylabel('Price level')
    if model == 'NS':
        ax.legend(custom_lines, legend)
    specs.save_figure(fig, 'figure_7_' + fig_letter, folder=folder)

    fig, ax = plt.subplots(figsize=specs.figsize_standard)
    plot_jacobian(J['real'], ax, color=colors_equivalence['menu_cost'])
    plot_jacobian(J_approx['real'], ax, color=colors_equivalence['td'], linestyle='--')
    ax.set_xlim([0, 40])
    ax.set_xlabel('Quarters')
    ax.set_ylabel('Inflation')
    specs.save_figure(fig, 'figure_7_' + fig_letter_2, folder=folder)

    # Figure 8: responses to geometric shock paths p ** t for each persistence p
    # in shock_persistence, menu cost (solid) vs Calvo (dashed);
    # price level (a/b), inflation (c/d).
    fig, ax = plt.subplots(figsize=specs.figsize_standard)
    for p in shock_persistence['nominal']:
        shock = p ** np.arange(T)
        ax.plot(J['nominal'] @ shock, color=colors_equivalence['menu_cost'])
        ax.plot(J_approx['nominal'] @ shock, color=colors_equivalence['td'], linestyle='--')
    ax.set_xlim([0, 12])
    ax.set_xlabel('Quarters')
    ax.set_ylabel('Price level')
    if model == 'NS':
        ax.legend(custom_lines, legend)
    specs.save_figure(fig, 'figure_8_' + fig_letter, folder=folder)

    fig, ax = plt.subplots(figsize=specs.figsize_standard)
    for p in shock_persistence['real']:
        shock = p ** np.arange(T)
        ax.plot(J['real'] @ shock, color=colors_equivalence['menu_cost'])
        ax.plot(J_approx['real'] @ shock, color=colors_equivalence['td'], linestyle='--')
    ax.set_xlim([0, 12])
    ax.set_xlabel('Quarters')
    ax.set_ylabel('Inflation')
    specs.save_figure(fig, 'figure_8_' + fig_letter_2, folder=folder)

    # Figure 10: real (inflation) Jacobians of the extensive-margin (solid) and
    # intensive-margin (dashed) components. J_td here is this model's dict built
    # earlier in the loop iteration, not the Part 1 dict.
    custom_lines_here = [Line2D([0], [0], color=colors_equivalence['extensive'], linestyle='-'),
                         Line2D([0], [0], color=colors_equivalence['intensive'], linestyle='--')]
    fig, ax = plt.subplots(figsize=specs.figsize_standard)
    plot_jacobian(J_td['ext']['real'], ax, color=colors_equivalence['extensive'], linestyle='-')
    plot_jacobian(J_td['int']['real'], ax, color=colors_equivalence['intensive'], linestyle='--')
    ax.set_xlim([0, 40])
    ax.set_xlabel('Quarters')
    ax.set_ylabel('Inflation')
    if model == 'NS':
        ax.legend(custom_lines_here, ['Extensive', 'Intensive'])
    specs.save_figure(fig, 'figure_10_' + fig_letter, folder=folder)

    # Figure 11: NK-model GE responses, inflation (top) and output (bottom),
    # using the menu-cost Phillips curve (solid) vs its Calvo approximation (dashed).
    xlim = 8
    fig, ax = plt.subplots(nrows=2, figsize=specs.figsize_standard)
    ax[0].plot(irf_nk[model]['pi'], linestyle='-', color=colors_equivalence['menu_cost'])
    ax[0].plot(irf_nk_approx[model]['pi'], linestyle='--', color=colors_equivalence['td'])
    ax[0].set_xlim([0, xlim])
    ax[0].set_ylabel('Inflation')
    ax[1].plot(irf_nk[model]['y'], linestyle='-', color=colors_equivalence['menu_cost'])
    ax[1].plot(irf_nk_approx[model]['y'], linestyle='--', color=colors_equivalence['td'])
    ax[1].set_xlabel('Quarters')
    ax[1].set_xlim([0, xlim])
    ax[1].set_ylabel('Output')
    if model == 'NS':
        ax[0].legend(custom_lines, legend, loc=4)
        ax[1].set_yticks([-0.05, 0])
    specs.save_figure(fig, 'figure_11_' + fig_letter, folder=folder)

    # Figure 12: same comparison in the Smets-Wouters model, over a longer window.
    xlim = 12
    fig, ax = plt.subplots(nrows=2, figsize=specs.figsize_standard)
    ax[0].plot(irf_sw['pi'], linestyle='-', color=colors_equivalence['menu_cost'])
    ax[0].plot(irf_sw_approx['pi'], linestyle='--', color=colors_equivalence['td'])
    ax[0].set_xlim([0, xlim])
    ax[0].set_ylabel('Inflation')
    ax[1].plot(irf_sw['y'], linestyle='-', color=colors_equivalence['menu_cost'])
    ax[1].plot(irf_sw_approx['y'], linestyle='--', color=colors_equivalence['td'])
    ax[1].set_xlabel('Quarters')
    ax[1].set_xlim([0, xlim])
    ax[1].set_ylabel('Output')
    if model == 'NS':
        ax[0].legend(custom_lines, legend, loc=4)
    specs.save_figure(fig, 'figure_12_' + fig_letter, folder=folder)


# The data column sits immediately right of the model columns (GL -> 1, NS -> 2).
# Derive it from model_names instead of the loop variable col_table, which only
# holds the last model's column because it leaked out of the loop.
write_table(worksheet_calibration, ['Data', '', '', '', freq_data, med_abs_dp_data], len(model_names) + 1, row_start=0)
workbook_calibration.close()


# %% Part 3: figures 9, C1

# Target durations (in model periods) at which the menu-cost model is calibrated.
ndur = 20
duration_values = np.linspace(0.5, 8, ndur)
# Shares of adjustments that are free; the sweep sets la = share_free * freq.
share_free_values = [0, 0.25, 0.5, 0.75, 0.9]
nshare = len(share_free_values)

# Duration implied by the data frequency, using duration = 1 / freq - 1.
duration_data = 1 / freq_data - 1

# Result arrays indexed [duration, share]. NOTE(review): the 'Nominal' entries
# and ``slope`` are not filled anywhere in the visible code.
dist = {'Nominal': np.zeros((ndur, nshare)),
        'Real': np.zeros((ndur, nshare))}
theta = {'Nominal': np.zeros((ndur, nshare)),
         'Real': np.zeros((ndur, nshare)),
         'ALL': np.zeros((ndur, nshare))}
eigs_ratio = np.zeros((ndur, nshare))
duration = {'Nominal': np.zeros((ndur, nshare)),
            'Real': np.zeros((ndur, nshare)),
            'ALL': np.zeros((ndur, nshare))}
slope = {'Nominal': np.zeros(nshare),
         'Real': np.zeros(nshare),
         'ALL': np.zeros(nshare)}

# Sweep: for each share of free adjustments and each target duration, calibrate
# the menu-cost model to the implied frequency and record the Calvo fit, the
# kurtosis-based Calvo theta and the eigenvalue ratio. (The former ``iter``
# counter was never read and shadowed the builtin iter().)
for ishare, share_free in enumerate(share_free_values):
    for idur, dur in enumerate(duration_values):
        freq = 1 / (1 + dur)

        # Warm start: at this point ``parameters`` still holds the previous
        # duration's calibrated values; restart from -6 for each new share.
        if idur == 0:
            mu_k_guess = -6
        else:
            mu_k_guess = parameters['mu_k']

        targs = {'freq': freq}
        initial_guess = {'mu_k': mu_k_guess}

        parameters = parameters_base.copy()
        parameters.update({'la': share_free * freq})

        parameters, ss = mc.calibrate_model(targs, parameters, initial_guess, nx)
        J_nom, _, _, _, f_ext, f_int, w_ext, w_int = mc.compute_td_equivalence(parameters, nx, T, ss, h)
        J_real = mc.compute_PC(J_nom, parameters, ss=ss, h=h, T_extend=T_extend, permanent_shock_to_normalize=True, cut_back=True)

        xbar = np.interp((1 + parameters['la']) / 2, ss['q'][nx // 2:], ss['x_grid'][nx // 2:])
        eigs = mc.compute_eigenvalues(xbar, parameters['la'], parameters['sig'], parameters['tau'])
        eigs_ratio[idur, ishare] = eigs[1] / eigs[0]

        _, theta['Real'][idur, ishare], _, _, dist['Real'][idur, ishare] = mc.approx_jacobian(J_real, beta, Price=False, Nominal=False, Absolute=True)

        # theta = (y - 3) / (y + 3) inverts y = 3 * (1 + theta) / (1 - theta),
        # where y = kurtosis / frequency and the Calvo frequency is 1 - theta.
        y = ss['stats']['kurt_dp'] / ss['stats']['freq']
        theta['ALL'][idur, ishare] = (y - 3) / (y + 3)  # the kurtosis of the Calvo model in discrete time is 3 * (1 + theta)

        # Calvo duration implied by each theta: 1 / (1 - theta) - 1.
        duration['Real'][idur, ishare] = 1 / (1 - theta['Real'][idur, ishare]) - 1
        duration['ALL'][idur, ishare] = 1 / (1 - theta['ALL'][idur, ishare]) - 1


custom_lines_here = [Line2D([0], [0], color=color, linestyle='-') for color in colors_duration]

# Figure 9a: distance (in %, as returned by mc.approx_jacobian) between the
# menu-cost real Jacobian and its Calvo approximation, against duration, with
# one line per share of free adjustments. The legend list built here is not
# drawn on this figure; the next figure reuses it.
fig, ax = plt.subplots(figsize=specs.figsize_standard)
legend = []
for k, share_free in enumerate(share_free_values):
    ax.plot(duration_values, 100 * dist['Real'][:, k], color=colors_duration[k])
    # Raw f-strings: '\%' is an invalid escape in a normal literal (SyntaxWarning
    # on Python 3.12+). The raw form yields the same text, backslash included.
    if k == 0:
        legend += [rf'{100*share_free:.0f}\% free adjustments']
    else:
        legend += [rf'{100 * share_free:.0f}\%']
ax.set_xlabel('Duration')
ax.set_ylabel(r'Distance (\%)')
ylim = ax.get_ylim()
# Faint dotted vertical line at the duration implied by the data frequency.
ax.plot([duration_data, duration_data], ylim, linestyle=':', color='black', linewidth=2, alpha=0.3)
ax.set_ylim(ylim)
specs.save_figure(fig, 'figure_9_a', folder=folder)

# Figure 9b: Calvo duration 1 / (1 - theta) - 1 implied by the fitted theta
# (solid) and by the kurtosis-based theta['ALL'] (dashed), against the
# state-dependent duration. ``legend`` is the list built for figure 9a.
fig, ax = plt.subplots(figsize=specs.figsize_standard)
for k, share_free in enumerate(share_free_values):
    ax.plot(duration_values, 1 / (1 - theta['Real'][:, k]) - 1, color=colors_duration[k])
ax.set_xlabel('State-dependent duration')
ax.set_ylabel('Calvo duration')
# The legend uses the solid custom handles, so it identifies colors (shares) only.
ax.legend(custom_lines_here, legend, frameon=False)
for k, share_free in enumerate(share_free_values):
    ax.plot(duration_values, 1 / (1 - theta['ALL'][:, k]) - 1, color=colors_duration[k], linestyle='--')
ylim = ax.get_ylim()
ax.plot([duration_data, duration_data], ylim, linestyle=':', color='black', linewidth=2, alpha=0.3)
ax.set_ylim(ylim)
specs.save_figure(fig, 'figure_9_b', folder=folder)

# Figure C1: ratio eigs[1] / eigs[0] of the eigenvalues from
# mc.compute_eigenvalues, against duration, one line per share of free adjustments.
fig, ax = plt.subplots(figsize=specs.figsize_standard)
legend = []
for k, share_free in enumerate(share_free_values):
    ax.plot(duration_values, eigs_ratio[:, k], color=colors_duration[k])
    # Raw f-strings keep the literal '\%' without an invalid-escape SyntaxWarning.
    if k == 0:
        legend += [rf'{100*share_free:.0f}\% free adjustments']
    else:
        legend += [rf'{100 * share_free:.0f}\%']
ax.set_xlabel('Duration')
ax.set_ylabel('Ratio')
ylim = ax.get_ylim()
# Faint dotted vertical line at the duration implied by the data frequency.
ax.plot([duration_data, duration_data], ylim, linestyle=':', color='black', linewidth=2, alpha=0.3)
ax.set_ylim(ylim)
ax.legend(custom_lines_here, legend, frameon=False, fontsize=16)
specs.save_figure(fig, 'figure_C1', folder=folder)


# %% Part 4: figures D1, D2

model_names = ['GL', 'NS']
share_free_values = [0, 0.75]  # share of free adjustments for GL and NS, respectively

J = {model: {} for model in model_names}
J_approx = {model: {} for model in model_names}
dist = {model: {} for model in model_names}

# Values assigned to parameters['mu'] below: 0.02 and 0.05 divided by
# n_periods (presumably annual rates converted to per-period — confirm).
mu_values = [0.02 / n_periods, 0.05 / n_periods]

# Legend handles: menu cost (solid) vs Calvo approximation (dashed).
custom_lines = [Line2D([0], [0], color=colors_equivalence['menu_cost'], linestyle='-'),
                Line2D([0], [0], color=colors_equivalence['td'], linestyle='--')]
legend = ['Menu cost', 'Calvo']

# Each model is recalibrated to both the frequency and median |dp| targets.
targs = {'freq': freq_data, 'med_abs_dp': med_abs_dp_data}
initial_guess = {'mu_k': -5, 'sig': 0.05}

# Figures D1 (first mu) and D2 (second mu): for each model, recalibrate with the
# given mu and compare the menu-cost Jacobians with their Calvo approximation.
for mu_here in mu_values:

    if mu_here == mu_values[0]:
        fig_number = 'D1'
    else:
        fig_number = 'D2'

    for model, share_free in zip(model_names, share_free_values):
        if model == 'GL':
            fig_letter = 'a'
            fig_letter_2 = 'c'
        else:
            fig_letter = 'b'
            fig_letter_2 = 'd'

        parameters = parameters_base.copy()
        parameters.update({'mu': mu_here, 'la': share_free * freq_data})
        parameters, ss = mc.calibrate_model(targs, parameters, initial_guess, nx)

        # Only the first output (the nominal Jacobian) of the trend variant is used.
        J_nom_out = mc.compute_td_equivalence_trend(parameters, nx, T, ss)[0]
        J[model]['nominal'] = J_nom_out
        J[model]['real'] = mc.compute_PC(J[model]['nominal'], parameters, ss=ss, h=h, T_extend=T_extend, permanent_shock_to_normalize=True, cut_back=True)

        # Calvo fit to the real Jacobian. NOTE(review): kappa and dist_real are
        # not used later in this loop.
        J_approx[model]['real'], theta_real, _, _, dist_real = mc.approx_jacobian(J[model]['real'], beta, Price=False, Nominal=False, Absolute=True)
        kappa = (1 - theta_real) * (1 - beta * theta_real) / theta_real

        J_approx[model]['nominal'] = td.calvo_jacobian(theta_real, beta, len(J_approx[model]['real']))

        # Panel a/b: price level; panel c/d: inflation.
        fig, ax = plt.subplots(figsize=specs.figsize_standard)
        plot_jacobian(J[model]['nominal'], ax, color=colors_equivalence['menu_cost'])
        plot_jacobian(J_approx[model]['nominal'], ax, color=colors_equivalence['td'], linestyle='--')
        ax.set_xlim([0, 40])
        ax.set_xlabel('Quarters')
        ax.set_ylabel('Price level')
        specs.save_figure(fig, 'figure_' + fig_number + '_' + fig_letter, folder=folder)

        fig, ax = plt.subplots(figsize=specs.figsize_standard)
        plot_jacobian(J[model]['real'], ax, color=colors_equivalence['menu_cost'])
        plot_jacobian(J_approx[model]['real'], ax, color=colors_equivalence['td'], linestyle='--')
        ax.set_xlim([0, 40])
        ax.set_xlabel('Quarters')
        ax.set_ylabel('Inflation')
        if model == 'NS':
            ax.legend(custom_lines, legend)
        specs.save_figure(fig, 'figure_' + fig_number + '_' + fig_letter_2, folder=folder)


# %% Part 5: figure D4

model_names = ['GL', 'NS']
share_free_values = [0, 0.75]  # share of free adjustments for GL and NS, respectively

J = {model: {} for model in model_names}
J_approx = {model: {} for model in model_names}
dist = {model: {} for model in model_names}

# Value assigned to parameters['tau'] below (parameters_base uses tau = 1).
# NOTE(review): presumably the per-period probability of receiving a shock — confirm in menu_cost.
shock_prob = 0.5

# Legend handles: menu cost (solid) vs Calvo approximation (dashed).
custom_lines = [Line2D([0], [0], color=colors_equivalence['menu_cost'], linestyle='-'),
                Line2D([0], [0], color=colors_equivalence['td'], linestyle='--')]
legend = ['Menu cost', 'Calvo']

# Only the frequency is targeted here, and only mu_k gets an initial guess.
targs = {'freq': freq_data}
initial_guess = {'mu_k': -5}

# Figure D4: recalibrate each model with tau = shock_prob and compare the
# menu-cost Jacobians with their Calvo approximation.
for model, share_free in zip(model_names, share_free_values):

    if model == 'GL':
        fig_letter = 'a'
        fig_letter_2 = 'c'
    else:
        fig_letter = 'b'
        fig_letter_2 = 'd'

    parameters = parameters_base.copy()
    parameters.update({'tau': shock_prob, 'la': share_free * freq_data})
    parameters, ss = mc.calibrate_model(targs, parameters, initial_guess, nx)

    # Nominal Jacobian of output 'p' with respect to input 'gap' from mc.compute_jacobian.
    J_nom_out, _ = mc.compute_jacobian(parameters, nx, T, h, shock_list=['gap'], output_list=['p'])
    J[model]['nominal'] = J_nom_out['p']['gap']
    J[model]['real'] = mc.compute_PC(J[model]['nominal'], parameters, ss=ss, h=h, T_extend=T_extend, permanent_shock_to_normalize=True, cut_back=True)

    # Calvo fit to the real Jacobian. NOTE(review): kappa and dist_real are not
    # used later in this loop.
    J_approx[model]['real'], theta_real, _, _, dist_real = mc.approx_jacobian(J[model]['real'], beta, Price=False, Nominal=False, Absolute=True)
    kappa = (1 - theta_real) * (1 - beta * theta_real) / theta_real

    J_approx[model]['nominal'] = td.calvo_jacobian(theta_real, beta, len(J_approx[model]['real']))

    # Panel a/b: price level; panel c/d: inflation.
    fig, ax = plt.subplots(figsize=specs.figsize_standard)
    plot_jacobian(J[model]['nominal'], ax, color=colors_equivalence['menu_cost'])
    plot_jacobian(J_approx[model]['nominal'], ax, color=colors_equivalence['td'], linestyle='--')
    ax.set_xlim([0, 40])
    ax.set_xlabel('Quarters')
    ax.set_ylabel('Price level')
    specs.save_figure(fig, 'figure_D4_' + fig_letter, folder=folder)

    fig, ax = plt.subplots(figsize=specs.figsize_standard)
    plot_jacobian(J[model]['real'], ax, color=colors_equivalence['menu_cost'])
    plot_jacobian(J_approx[model]['real'], ax, color=colors_equivalence['td'], linestyle='--')
    ax.set_xlim([0, 40])
    ax.set_xlabel('Quarters')
    ax.set_ylabel('Inflation')
    if model == 'NS':
        ax.legend(custom_lines, legend)
    specs.save_figure(fig, 'figure_D4_' + fig_letter_2, folder=folder)


# %% Part 6: figure D6

# Multi-sector setup: per-sector Jacobians are later combined using weights_sectors.
model_names = ['GL', 'NS']
share_free_values = [0, 0.75]  # share of free adjustments for GL and NS, respectively

J = {model: {} for model in model_names}
J_approx = {model: {} for model in model_names}
dist = {model: {} for model in model_names}
kappa = {model: {} for model in model_names}  # per-model, per-sector Calvo slopes

# Legend handles: menu cost (solid) vs Calvo approximation (dashed).
custom_lines = [Line2D([0], [0], color=colors_equivalence['menu_cost'], linestyle='-'),
                Line2D([0], [0], color=colors_equivalence['td'], linestyle='--')]
legend = ['Menu cost', 'Calvo']


for model, share_free in zip(model_names, share_free_values):

    if model == 'GL':
        fig_letter = 'a'
        fig_letter_2 = 'c'
    else:
        fig_letter = 'b'
        fig_letter_2 = 'd'

    for k in range(len(freq_sectors)):
        if k == 0:
            J[model][f'sector_{k + 1}'] = np.eye(T)  # first sector is too flexible, so assume it's just perfectly flex price
            J_approx[model][f'sector_{k + 1}'] = np.eye(T)
            J[model]['nominal'] = weights_sectors[k] * J[model][f'sector_{k + 1}']
            J_approx[model]['nominal'] = weights_sectors[k] * J_approx[model][f'sector_{k + 1}']
        else:
            parameters = parameters_base.copy()
            parameters.update({'la': share_free * freq_sectors[k]})
            targs = {'freq': freq_sectors[k]}
            parameters, ss = mc.calibrate_model(targs, parameters, {'mu_k': -5}, nx)

            J_nom_out = mc.compute_td_equivalence(parameters, nx, T, ss)[0]
            J[model][f'sector_{k + 1}'] = J_nom_out
            J[model][f'real_sector_{k + 1}'], J[model][f'sector_{k + 1}'] = mc.compute_PC(J[model][f'sector_{k + 1}'], parameters, ss, T_extend=T_extend, permanent_shock_to_normalize=True, cut_back=False, return_nominal=True)

            _, _, dist[model][f'sector_{k + 1}'], _, _ = mc.approx_jacobian(J[model][f'sector_{k + 1}'][:T, :T], beta, Price=True, Nominal=True, Absolute=True)
            _, theta, _, _, dist[model][f'real_sector_{k + 1}'] = mc.approx_jacobian(J[model][f'real_sector_{k + 1}'][:T, :T], beta, Price=False, Nominal=False, Absolute=True)
            J_approx[model][f'sector_{k + 1}'] = td.calvo_jacobian(theta, beta, T)

            kappa[model][f'sector_{k + 1}'] = (1 - theta) * (1 - beta * theta) / theta

            J[model]['nominal'] += weights_sectors[k] * J[model][f'sector_{k + 1}'][:T, :T]
            J_approx[model]['nominal'] += weights_sectors[k] * J_approx[model][f'sector_{k + 1}'][:T, :T]

    J[model]['real'] = mc.compute_PC(J[model]['nominal'])
    J_approx[model]['real'] = mc.compute_PC(J_approx[model]['nominal'])

    dist[model]['nominal'] = np.linalg.norm(J[model]['nominal'] - J_approx[model]['nominal'])
    dist[model]['real'] = np.linalg.norm(J[model]['real'] - J_approx[model]['real']) / np.linalg.norm(J[model]['real'])

    fig, ax = plt.subplots(figsize=specs.figsize_standard)
    plot_jacobian(J[model]['nominal'], ax, color=colors_equivalence['menu_cost'])
    plot_jacobian(J_approx[model]['nominal'], ax, color=colors_equivalence['td'], linestyle='--')
    ax.set_xlim([0, 40])
    ax.set_xlabel('Quarters')
    ax.set_ylabel('Price level')
    specs.save_figure(fig, 'figure_D6_' + fig_letter, folder=folder)

    fig, ax = plt.subplots(figsize=specs.figsize_standard)
    plot_jacobian(J[model]['real'], ax, color=colors_equivalence['menu_cost'])
    plot_jacobian(J_approx[model]['real'], ax, color=colors_equivalence['td'], linestyle='--')
    ax.set_xlim([0, 40])
    ax.set_xlabel('Quarters')
    ax.set_ylabel('Inflation')
    if model == 'NS':
        ax.legend(custom_lines, legend)
    specs.save_figure(fig, 'figure_D6_' + fig_letter_2, folder=folder)


# %% Part 7: figure D5

J, J_approx, dist = {}, {}, {}

custom_lines = [Line2D([0], [0], color=colors_equivalence[key], linestyle=style)
                for key, style in (('menu_cost', '-'), ('td', '--'))]
legend = ['Menu cost', 'Calvo']

# Multi-product (mp) model with tau = 1 and la = 0; 'shift' is dropped before calibration.
tau = 1
parameters = {**parameters_base, 'tau': tau, 'la': 0}
del parameters['shift']

targs = {'freq': freq_data}
parameters, ss = mp.calibrate_model(targs, parameters, {'mu_k': -3.8644342766540825}, optimize=False)

J_nom, _ = mp.compute_jacobian(parameters, ss['x_grid'], T, h, shock_list=['gap'], output_list=['p'])
J['nominal'] = J_nom['p']['gap']
J['real'] = mp.compute_PC(J['nominal'], ss['x_grid'], parameters, ss=ss, h=h, T_extend=T_extend,
                          permanent_shock_to_normalize=True, cut_back=True)[:T, :T]

# Fit Calvo to the inflation Jacobian; the same theta generates the price-level Jacobian.
J_approx['real'], theta_real, _, _, dist_real = mc.approx_jacobian(J['real'], beta, Price=False, Nominal=False, Absolute=True)
kappa = (1 - theta_real) * (1 - beta * theta_real) / theta_real

theta_nominal = theta_real
J_approx['nominal'] = td.calvo_jacobian(theta_nominal, beta, len(J_approx['real']))


for panel_key, panel_ylabel, panel_letter in (('nominal', 'Price level', 'a'), ('real', 'Inflation', 'b')):
    fig, ax = plt.subplots(figsize=specs.figsize_standard)
    plot_jacobian(J[panel_key], ax, color=colors_equivalence['menu_cost'])
    plot_jacobian(J_approx[panel_key], ax, color=colors_equivalence['td'], linestyle='--')
    ax.set_xlim([0, 40])
    ax.set_xlabel('Quarters')
    ax.set_ylabel(panel_ylabel)
    if panel_key == 'real':
        ax.legend(custom_lines, legend)
    specs.save_figure(fig, 'figure_D5_' + panel_letter, folder=folder)


# %% Part 8: figure D7

shock_size = [0.025, 0.05, 0.1]

model_names = ['GL', 'NS']
share_free_values = [0, 0.75]

targs = {'freq': freq_data, 'med_abs_dp': med_abs_dp_data}

custom_lines = [Line2D([0], [0], color=colors_equivalence[key], linestyle=style)
                for key, style in (('menu_cost', '-'), ('td', '--'))]

# Hard-coded (mu_k, sig) per model; this part does not re-run the calibration.
d7_calibration = {
    'GL': {'mu_k': -5.122741752515663, 'sig': 0.045874035207593526},
    'NS': {'mu_k': -2.9710458486831417, 'sig': 0.06035491762766833},
}
# panel letters per model, one per entry of shock_size
d7_panel_letters = {'GL': ['a', 'c', 'e'], 'NS': ['b', 'd', 'f']}

for model, share_free in zip(model_names, share_free_values):
    fig_letters = d7_panel_letters[model]

    parameters = parameters_base.copy()
    parameters.update({'la': share_free * freq_data})
    parameters.update(d7_calibration[model])

    J, ss, *_ = mc.compute_td_equivalence(parameters, nx, T)
    J_approx, _, _, _, _ = mc.approx_jacobian(J, beta, Price=True, Nominal=True, Absolute=True)

    for k, size in enumerate(shock_size):
        fig, ax = plt.subplots(figsize=specs.figsize_standard)
        # linear (Calvo) vs nonlinear (menu-cost) price response for each persistence
        for p in shock_persistence['nominal']:
            shock = size * (p ** np.arange(T))
            irf_linear = J_approx @ shock
            irf_nonlinear = mc.mit_shock(parameters, nx, shock_paths={'gap': shock}, output_list=['p'], ss=ss)['p']
            ax.plot(irf_nonlinear, color=colors_equivalence['menu_cost'])
            ax.plot(irf_linear, color=colors_equivalence['td'], linestyle='--')
        ax.set_xlim([0, 12])
        ax.set_xlabel('Quarters')
        ax.set_ylabel('Price')
        show_legend = model == 'NS' and k == 0
        if show_legend:
            ax.legend(custom_lines, ['Nonlinear state-dependent', 'Linear Calvo approx.'])
        specs.save_figure(fig, 'figure_D7_' + fig_letters[k], folder=folder)


# %% Part 9: figures E1, E2 (requires Part 2 above)

model_names = ['GL', 'NS']
share_free_values = [0, 0.75]

elast = 4
trend_inflation_values = [0, 0.02 / n_periods, 0.05 / n_periods]

targs = {'freq': freq_data, 'med_abs_dp': med_abs_dp_data}  # calibration targets

custom_lines = [Line2D([0], [0], color=colors_equivalence['menu_cost'], linestyle='-'),
                Line2D([0], [0], color=colors_equivalence['td'], linestyle='--')]
legend = ['Menu cost', 'Calvo']

rho_i = 0.5

# First-difference operator on the extended horizon: maps price-level Jacobians to inflation.
D = np.diag(np.ones(T_extend)) - np.diag(np.ones(T_extend - 1), -1)


irf_all = {}
irf_all_approx = {}

for model, share_free in zip(model_names, share_free_values):
    irf_all[model] = list()
    irf_all_approx[model] = list()

    # index named i_trend (not `iter`) so the builtin iter() is not shadowed
    for i_trend, trend_inflation in enumerate(trend_inflation_values):

        if i_trend == 0:
            # calibrate once, at the first trend-inflation value; later values only change mu
            parameters = parameters_base.copy()
            parameters.update({'la': share_free * freq_data, 'elast': elast, 'mu': trend_inflation})
            if model == 'GL':
                initial_guess = {'mu_k': -5.5945372186590720, 'sig': 0.04582172108191538}
            else:
                initial_guess = {'mu_k': -3.4895061850831284, 'sig': 0.06005034446810007}

            parameters, _ = mc.calibrate_model(targs=targs, initial_guess=initial_guess, parameters=parameters, nx=nx)

        else:
            parameters.update({'mu': trend_inflation})

        shock_list = ['gap', 'discount', 'shift']
        output_list = ['p', 'p_star', 'freq']
        J_pos, ss = mc.compute_jacobian(parameters, nx, T, h, shock_list=shock_list, output_list=output_list)
        J_neg, _ = mc.compute_jacobian(parameters, nx, T, -h, shock_list=shock_list, output_list=output_list)

        # average the +h and -h Jacobians (symmetric derivative) for accuracy
        J = {out: {shock: 0.5 * (J_pos[out][shock] + J_neg[out][shock]) for shock in shock_list} for out in output_list}

        # add delta = elast * (p_star - p)
        output_list += ['delta']
        J['delta'] = {}
        for shock in shock_list:
            J['delta'][shock] = elast * (J['p_star'][shock] - J['p'][shock])

        # Rescale each row t of the price Jacobian so that its row sum (the response to a
        # permanent gap shock) equals permanent_shock[t]. The grid size gets its own name so
        # the module-level `nx` used by the calls above and by later parts is not overwritten.
        nx_grid = len(ss['x_grid'])
        permanent_shock = mc.permanent_gap_shock(parameters, nx_grid, T, h, ss['x_grid'], ss['Pi'], ss['g'])
        J['p']['gap'] = J['p']['gap'] * permanent_shock[:, np.newaxis] / np.sum(J['p']['gap'], axis=1)[:, np.newaxis]

        # extend matrix
        for output in output_list:
            for shock in shock_list:
                J[output][shock] = mc.extend_matrix(J[output][shock], T_extend)

        # Adjust price Jacobians to real shocks. The system matrix is formed once from the
        # nominal-gap Jacobian before 'gap' itself is overwritten.
        I_minus_Jp_gap = np.eye(T_extend) - J['p']['gap']
        J['p']['discount'] = np.linalg.solve(I_minus_Jp_gap, J['p']['discount'])
        J['p']['shift'] = np.linalg.solve(I_minus_Jp_gap, J['p']['shift'])
        J['p']['gap'] = np.linalg.solve(I_minus_Jp_gap, J['p']['gap'])

        # adjust other jacobians to real shocks (presumably nominal gap = real gap + price level)
        I_plus_Jp_gap = np.eye(T_extend) + J['p']['gap']
        J['freq']['gap'] = J['freq']['gap'] @ I_plus_Jp_gap
        J['delta']['gap'] = J['delta']['gap'] @ I_plus_Jp_gap

        # add inflation jacobians
        output_list += ['pi']
        J['pi'] = {shock: D @ J['p'][shock] for shock in shock_list}

        # truncate everything back to T x T
        for output in output_list:
            for shock in shock_list:
                J[output][shock] = J[output][shock][:T, :T]

        J_gap_approx, theta, beta_here, _, _, _ = mc.approx_jacobian(J['pi']['gap'], beta=None, Price=False, Nominal=False, Absolute=True)
        kappa = (1 - theta) * (1 - beta_here * theta) / theta
        J_approx = {'pi': {'gap': J_gap_approx, 'discount': np.zeros((T, T)), 'shift': np.zeros((T, T))}}

        irf = compute_ge_full_irfs(J, parameters, ss, rho_i, T)
        irf_approx = compute_ge_irfs(J_approx, rho_i, T)

        irf_all[model].append(irf)
        irf_all_approx[model].append(irf_approx)


for model in ['GL', 'NS']:
    if model == 'GL':
        fig_letters = ['a', 'c']
    else:
        fig_letters = ['b', 'd']

    xlim = 8
    fig, ax = plt.subplots(nrows=2, figsize=specs.figsize_standard)
    ax[0].plot(irf_nk[model]['pi'], linestyle='-', color='royalblue')
    ax[0].plot(irf_all[model][0]['pi'], linestyle='-.', color='darkblue')
    ax[0].plot(irf_nk_approx[model]['pi'], linestyle='--', color=colors_equivalence['td'])
    ax[0].set_xlim([0, xlim])
    ax[0].set_ylabel('Inflation')
    # ax[0].set_ylim([-0.2, 0.02])
    # ax[0].set_yticks([-0.2, -0.1, 0])
    ax[1].plot(irf_nk[model]['y'], linestyle='-', color='royalblue')
    ax[1].plot(irf_all[model][0]['y'], linestyle='-.', color='darkblue')
    ax[1].plot(irf_nk_approx[model]['y'], linestyle='--', color=colors_equivalence['td'])
    ax[1].set_xlabel('Quarters')
    ax[1].set_xlim([0, xlim])
    if model == 'NS':
        ax[1].set_ylim([-0.06, 0.005])
    # ax[1].set_yticks([-0.2, -0.1, 0])
    ax[1].set_ylabel('Output')
    if model == 'NS':
        ax[1].legend(['Small id. shocks', 'Large id. shocks', 'Calvo approx.'], loc=4)
    specs.save_figure(fig, 'figure_E1_' + fig_letters[0])

    xlim = 8
    fig, ax = plt.subplots(figsize=specs.figsize_standard)
    ax.plot(irf_all[model][1]['pi'], linestyle='-', color='mediumblue')
    ax.plot(irf_all_approx[model][1]['pi'], linestyle='--', color='firebrick')
    ax.set_xlim([0, xlim])
    ax.set_ylabel('Inflation')
    ax.set_xlabel('Quarters')
    specs.save_figure(fig, 'figure_E2_' + fig_letters[0])

    xlim = 8
    fig, ax = plt.subplots(figsize=specs.figsize_standard)
    ax.plot(irf_all[model][2]['pi'], linestyle='-', color='mediumblue')
    ax.plot(irf_all_approx[model][2]['pi'], linestyle='--', color='firebrick')
    ax.set_xlim([0, xlim])
    ax.set_ylabel('Inflation')
    ax.set_xlabel('Quarters')
    if model == 'NS':
        # Labels follow plotting order: solid = state-dependent model with trend inflation,
        # dashed = its Calvo approximation (previously the two labels were swapped).
        ax.legend(['SD with trend inflation', 'Calvo approx. (no trend inflation)'], loc=4)
    specs.save_figure(fig, 'figure_E2_' + fig_letters[1])


# %% Part 10: figure A1

linestyles = ['-', '--', '-.']

J = 500  # truncation: number of terms kept in the series expansion
xbar = 0.1
sig = 0.05
zeta = 0.1
nx_here = 1000
# Grid on [-xbar, xbar]; its size comes from nx_here (previously hard-coded as 1000,
# which left nx_here unused).
x_grid = np.linspace(-xbar, xbar, nx_here)

# Series terms on [-xbar, xbar]: sine basis functions phi_j, (negative) decay rates la_j,
# and coefficients b_j that are non-zero only for even j.
b = np.zeros(J)
la = np.zeros(J)
phi = []
for j in range(1, J + 1):
    phi_here = 1 / np.sqrt(xbar) * np.sin((x_grid + xbar) / (2 * xbar) * j * np.pi)
    phi.append(phi_here)
    la[j - 1] = -(zeta + (0.5 * sig ** 2) * (j * np.pi / (2 * xbar)) ** 2)
    if j % 2 == 0:
        b[j - 1] = 4 * (xbar ** (3 / 2)) / (j * np.pi)

# Closed-form constant p; zeta == 0 is handled separately because alpha = 0 would make
# the denominator cosh(0) - 1 vanish.
if zeta == 0:
    p = (sig / xbar) ** 2
else:
    alpha = np.sqrt(2 * zeta / sig ** 2)
    p = zeta * np.cosh(alpha * xbar) / (np.cosh(alpha * xbar) - 1)


def E(t):
    """Evaluate E(x, t) on the module-level ``x_grid``.

    Returns minus the truncated series sum over j < J of b[j] * exp(la[j] * t) * phi[j],
    using the module-level arrays ``b``, ``la`` and list ``phi``.
    """
    acc = np.zeros_like(x_grid)
    for idx in range(J):
        weight = b[idx] * np.exp(la[idx] * t)
        acc += weight * phi[idx]
    return -acc

fig, ax = plt.subplots(figsize=specs.figsize_standard)
t_vals = [0.1, 0.5, 1.0]
legend = []
for k, t in enumerate(t_vals):
    ax.plot(x_grid, E(t), color=colors_Es[k + 1], linestyle=linestyles[k])
    legend += [f't = {t}']
if add_titles:
    ax.set_title('E(x,t)')
ax.legend(legend)
ax.set_ylim([-0.08, 0.08])
ax.set_yticks([-0.06, -0.03, 0, 0.03, 0.06])
ax.set_xlabel('Price gap')
specs.save_figure(fig, 'figure_A1_a', folder=folder)

tmax = 4
n = 100
t_grid = np.linspace(0, tmax, n)

# Term masks aligned with `la` (length J). np.resize repeats the pattern cyclically, so the
# masks always have length J, even when J is not a multiple of 4 (the previous
# int(J / 4) * [...] construction came up short in that case and broke the products with la).
alternating = np.resize(np.array([0, -1, 0, 1]), J)
only_even = np.resize(np.array([0, 1]), J)

# survival functions up to normalization (extensive and intensive margins)
S_ext = np.array([2 * np.sum(only_even * np.exp(la * t)) for t in t_grid])
S_int = -np.array([2 * np.sum(alternating * np.exp(la * t)) for t in t_grid])
S_int[0] = S_int[1]  # t = 0 value replaced by the next grid point

# time derivatives of the survival functions (differentiating exp(la t) brings down la)
S_p_ext = np.array([2 * np.sum(only_even * la * np.exp(la * t)) for t in t_grid])
S_p_int = -np.array([2 * np.sum(alternating * la * np.exp(la * t)) for t in t_grid])

# hazard rates -S'(t) / S(t)
la_ext = -S_p_ext / S_ext
la_int = -S_p_int / S_int
la_int[0] = la_int[1]

integral_S_ext = -np.sum(2 * only_even / la)

fig, ax = plt.subplots(figsize=specs.figsize_standard)
ax.plot(t_grid, S_ext, color=colors_equivalence['extensive'], linestyle=linestyles[0])
ax.plot(t_grid, S_int, color=colors_equivalence['intensive'], linestyle=linestyles[1])
ax.set_ylim([0, 3])
ax.set_xlim([0, tmax])
ax.set_xlabel('Quarters')
# ax.legend(['Extensive margin', 'Intensive margin'])
specs.save_figure(fig, 'figure_A1_b', folder=folder)

fig, ax = plt.subplots(figsize=specs.figsize_standard)
ax.plot(t_grid, la_ext, color=colors_equivalence['extensive'], linestyle=linestyles[0])
ax.plot(t_grid, la_int, color=colors_equivalence['intensive'], linestyle=linestyles[1])
ax.set_ylim([0, 3])
ax.set_xlim([0, tmax])
ax.set_xlabel('Quarters')
ax.legend(['Extensive margin', 'Intensive margin'])
specs.save_figure(fig, 'figure_A1_c', folder=folder)


# %% Part 11: table D1

model_names = ['GL', 'NS']
share_free_values = [0, 0.75]
targs = {'freq': freq_data}  # calibration targets

N = 10000  # length of the simulated sample

# Marginal-cost shock processes handed to ar_irf: roots [0.8] and [0] (presumably a
# persistent AR(1) and an i.i.d. process).
# NOTE(review): sd_list[0] = 1 - 0.8**2 is the innovation *variance* that would give unit
# unconditional variance; confirm whether ar_irf expects a standard deviation here.
roots_list = [[0.8], [0]]
sd_list = [1 - 0.8 ** 2, 1]
n = len(roots_list)

# Each impulse response is ar_irf's T-period response, zero-padded to the full sample length N.
padding = np.zeros(N - T)
mc_irf = [np.concatenate([ar_irf(roots, sd, T), padding])
          for roots, sd in zip(roots_list, sd_list)]


def write_col(sheet, entries, col):
    """Write ``entries`` top-to-bottom into column ``col`` of ``sheet``, starting at row 0.

    ``sheet`` only needs a ``write(row, col, value)`` method.
    """
    row = 0
    for value in entries:
        sheet.write(row, col, value)
        row += 1


# The workbook is used as a context manager so it is closed (and the file written) even if a
# calibration or regression below raises. The path has no placeholders, so no f-string.
with xlsxwriter.Workbook('tables/table_D1.xlsx') as workbook:
    worksheet = workbook.add_worksheet()

    # row labels
    write_col(worksheet, ['', 'kappa', 'forward', 'lagged', 'R2'], col=0)

    for model, share_free in zip(model_names, share_free_values):
        np.random.seed(0)  # identical shock draws for both models

        if model == 'GL':
            col_start = 1
        else:
            col_start = 6

        parameters = parameters_base.copy()
        parameters.update({'la': share_free * freq_data})
        if model == 'GL':
            initial_guess = {'mu_k': -4.950639364609949}
        else:
            initial_guess = {'mu_k': -3.3474886932977386}
        parameters, ss = mc.calibrate_model(targs, parameters, initial_guess, nx)

        J_nom, *_ = mc.compute_td_equivalence(parameters, nx, T, ss=ss, h=h)
        J_real = mc.compute_PC(J_nom, parameters, ss=ss, h=h, T_extend=T_extend, permanent_shock_to_normalize=True, cut_back=True)
        J_approx, theta_approx, *_ = mc.approx_jacobian(J_real, beta, Nominal=False, Price=False)

        # inflation responses to each marginal-cost shock, zero-padded to length N
        pi_irf = [np.concatenate([J_real @ irf[:T], np.zeros(N - T)]) for irf in mc_irf]
        pi_irf_approx = [np.concatenate([J_approx @ irf[:T], np.zeros(N - T)]) for irf in mc_irf]

        # Simulate by superposing impulse responses to n shocks drawn each period.
        # E_pi_sim[s] only includes shocks dated s - 1 or earlier, i.e. the date-(s-1)
        # forecast of pi_s (later shocks have mean zero).
        pi_sim = np.zeros(N)
        pi_approx_sim = np.zeros(N)
        E_pi_sim = np.zeros(N)
        mc_sim = np.zeros(N)
        for t in range(N):
            shocks = np.random.randn(n)
            for j in range(n):
                mc_sim[t:] += shocks[j] * mc_irf[j][:N - t]
                pi_sim[t:] += shocks[j] * pi_irf[j][:N - t]
                pi_approx_sim[t:] += shocks[j] * pi_irf_approx[j][:N - t]
                E_pi_sim[t + 1:] += shocks[j] * pi_irf[j][1:N - t]

        # align pi_{t-1}, pi_t, E_t pi_{t+1} and mc_t
        pi_lag = pi_sim[:-2]
        pi = pi_sim[1:-1]
        pi_for = E_pi_sim[2:]
        real_mc = mc_sim[1:-1]

        # Discount factor implied by the calibrated interest rate. It has its own name so the
        # module-level `beta` (used above and by later parts) is not overwritten.
        beta_model = 1 / (1 + parameters['r'])

        # column 1: Calvo slope implied by the fitted theta
        kappa = (1 - theta_approx) * (1 - beta_model * theta_approx) / theta_approx
        pi_fit = beta_model * pi_for + kappa * real_mc
        R2 = 1 - np.sum((pi - pi_fit) ** 2) / np.sum(pi ** 2)

        write_col(worksheet, ['K approx.', kappa, beta_model, 0, R2], col=col_start)

        # column 2: estimate kappa only, discount factor fixed, no lag
        reg = sm.OLS(pi - beta_model * pi_for, real_mc)
        res = reg.fit()
        se = np.sqrt(np.diag(res.cov_params()))
        pi_fit = beta_model * pi_for + res.params[0] * real_mc
        R2 = 1 - np.sum((pi - pi_fit) ** 2) / np.sum(pi ** 2)

        write_col(worksheet, ['beta = 0.99, gamma = 0', res.params[0], beta_model, 0, R2], col_start + 1)

        # column 3: estimate kappa and the forward coefficient, no lag
        X = np.column_stack([real_mc, pi_for])
        reg = sm.OLS(pi, X)
        res = reg.fit()
        se = np.sqrt(np.diag(res.cov_params()))

        write_col(worksheet, ['gamma = 0', res.params[0], res.params[1], 0, res.rsquared], col=col_start + 2)

        # column 4: unrestricted, including lagged inflation
        X = np.column_stack([real_mc, pi_for, pi_lag])
        reg = sm.OLS(pi, X)
        res = reg.fit()
        se = np.sqrt(np.diag(res.cov_params()))

        write_col(worksheet, ['Unrestricted', res.params[0], res.params[1], res.params[2], res.rsquared], col=col_start + 3)


# %% Part 12: figures C2, D3

model_names = ['GL', 'NS']

targs = {'freq': freq_data, 'med_abs_dp': med_abs_dp_data}  # calibration targets

mu_values = np.linspace(0, 0.15, 30) / n_periods

# x-axis label. "\%" is an invalid Python escape sequence (SyntaxWarning on 3.12+); the raw
# literal has the same value and keeps the backslash for LaTeX rendering.
trend_label = r'Trend inflation (\% p.a.)'

for model in model_names:
    if model == 'GL':
        share_free = 0
        slope_here = 0.16595705394631
        fig_letter = 'a'
        fig_letter_2 = 'c'
    else:
        share_free = 0.75
        slope_here = 0.377270260460157
        fig_letter = 'b'
        fig_letter_2 = 'd'

    parameters = parameters_base.copy()
    parameters.update({'la': share_free * freq_data})
    if model == 'GL':
        parameters.update({'mu_k': -5.122741752515663, 'sig': 0.045874035207593526})
    else:
        parameters.update({'mu_k': -2.9710458486831417, 'sig': 0.06035491762766833})

    theta_values = np.zeros_like(mu_values)
    kappa_values = np.zeros_like(mu_values)
    freq_values = np.zeros_like(mu_values)
    dist_values = np.zeros_like(mu_values)
    w_int_values = np.zeros_like(mu_values)
    w_upper_values = np.zeros_like(mu_values)
    w_lower_values = np.zeros_like(mu_values)
    asymptotic_hazard = np.zeros_like(mu_values)
    frequency = np.zeros_like(mu_values)
    kappa_no_trend = np.zeros_like(mu_values)

    # NOTE(review): this steady state is computed once, at the base parameters' mu (0), and is
    # passed to mc.compute_PC for every mu below, while compute_td_equivalence_trend returns
    # a mu-specific `ss_new`. Confirm whether compute_PC should receive `ss_new` instead.
    ss = mc.steady_state(parameters, nx)

    for i, mu in enumerate(mu_values):
        parameters.update({'mu': mu})

        J_nom, ss_new, _, _, _, f_int, f_lower, f_upper, w_int, w_lower, w_upper = mc.compute_td_equivalence_trend(parameters, nx, T, ss=None, h=h)

        # one-period hazards 1 - f[t+1] / f[t]; the denominator is floored at 1e-8
        hazards_int = 1 - f_int[1:] / np.maximum(f_int[:-1], 1e-8)
        hazards_lower = 1 - f_lower[1:] / np.maximum(f_lower[:-1], 1e-8)
        hazards_upper = 1 - f_upper[1:] / np.maximum(f_upper[:-1], 1e-8)

        J_real = mc.compute_PC(J_nom, parameters, ss=ss, h=h, T_extend=T_extend, permanent_shock_to_normalize=False, cut_back=True)
        J_approx, theta_calvo, _, _, dist = mc.approx_jacobian(J_real, beta, Price=False, Nominal=False, Absolute=True)

        if i == 0:
            # reference values at the first grid point, mu = 0
            theta_calvo_0 = theta_calvo
            freq_calvo_0 = 1 - theta_calvo
            dur_calvo_0 = 1 / freq_calvo_0 - 1
            freq_sd_0 = ss_new['stats']['freq']
            dur_sd_0 = 1 / freq_sd_0 - 1

        kappa = (1 - theta_calvo) * (1 - beta * theta_calvo) / theta_calvo

        theta_values[i] = theta_calvo
        kappa_values[i] = kappa
        freq_values[i] = ss_new['stats']['freq']
        dist_values[i] = dist
        w_int_values[i] = w_int
        w_upper_values[i] = w_upper
        w_lower_values[i] = w_lower
        asymptotic_hazard[i] = hazards_int[10]  # 10 appears to be large enough for convergence
        frequency[i] = ss_new['stats']['freq']

        # Slope for the "No trend inflation, same freq." line: the Calvo duration moves with
        # the menu-cost duration at rate slope_here relative to the mu = 0 reference.
        dur_sd = 1 / frequency[i] - 1
        dur_calvo = slope_here * (dur_sd - dur_sd_0) + dur_calvo_0
        freq_calvo = 1 / (1 + dur_calvo)
        theta_no_trend = 1 - freq_calvo
        kappa_no_trend[i] = (1 - theta_no_trend) * (1 - beta * theta_no_trend) / theta_no_trend

    fig, ax = plt.subplots(figsize=specs.figsize_standard)
    ax.plot(100 * n_periods * mu_values, kappa_values, color='darkblue')
    ax.plot(100 * n_periods * mu_values, kappa_no_trend, color='green', ls='--')
    ax.set_xlabel(trend_label)
    ax.set_ylabel('Calvo slope')
    if model == 'NS':
        ax.legend(['With trend inflation', 'No trend inflation, same freq.'])
    specs.save_figure(fig, 'figure_D3_' + fig_letter)

    fig, ax = plt.subplots(figsize=specs.figsize_standard)
    ax.plot(100 * n_periods * mu_values, dist_values, color='darkblue')
    ax.set_xlabel(trend_label)
    ax.set_ylabel('Distance')
    specs.save_figure(fig, 'figure_D3_' + fig_letter_2)

    fig, ax = plt.subplots(figsize=specs.figsize_standard)
    ax.plot(100 * n_periods * mu_values, w_int_values, color=colors_equivalence['intensive'])
    ax.plot(100 * n_periods * mu_values, w_lower_values, color=colors_equivalence['extensive'], ls='--')
    ax.plot(100 * n_periods * mu_values, w_upper_values, color=colors_equivalence['actual_hazards'], ls='-.')
    ax.set_xlabel(trend_label)
    ax.set_ylabel('Weights')
    if model == 'NS':
        ax.legend(['Intensive', 'Lower', 'Upper'])
    specs.save_figure(fig, 'figure_C2_' + fig_letter)

    fig, ax1 = plt.subplots(figsize=specs.figsize_standard)
    ax1.set_xlabel(trend_label)
    # ax1.set_ylabel('Frequency', color=colors_equivalence['intensive'])
    ax1.plot(100 * n_periods * mu_values, frequency, color=colors_equivalence['intensive'])
    ax1.tick_params(axis='y', colors=colors_equivalence['intensive'])
    ax1.spines['left'].set_color(colors_equivalence['intensive'])
    ax2 = ax1.twinx()
    # ax2.set_ylabel('Asymptotic virtual hazard', color=colors_equivalence['extensive'])
    ax2.plot(100 * n_periods * mu_values, asymptotic_hazard, color=colors_equivalence['extensive'], ls='--')
    ax2.tick_params(axis='y', colors=colors_equivalence['extensive'])
    # Each spine takes the colour of the series on its side. The right spine belongs to the
    # twin axes (hazard). The twin axes also has its own left spine, drawn over ax1's, so it
    # gets the frequency colour too. (Previously ax1's left spine was recoloured to the
    # hazard colour here, undoing the line above.)
    ax2.spines['left'].set_color(colors_equivalence['intensive'])
    ax2.spines['right'].set_color(colors_equivalence['extensive'])
    ax2.grid(False)
    if model == 'NS':
        ax1.legend([Line2D([0], [0], color=colors_equivalence['intensive'], linestyle='-'),
                    Line2D([0], [0], color=colors_equivalence['extensive'], linestyle='--')],
                   ['Frequency (left)', 'Asymptotic virtual hazard (right)'])
    specs.save_figure(fig, 'figure_C2_' + fig_letter_2)


# %% Part 13: figure D8

T_irf = 200

# Target impact sizes (annual rates 5% and 10%, expressed per period).
shock_size = [rate / n_periods for rate in (0.05, 0.1)]
persistence = 0.8
# Geometric marginal-cost path with unit impact: persistence ** t for t = 0, ..., T_irf - 1.
mc_path = np.power(persistence, np.arange(T_irf))


def real_mc_shock(mc_path, p_linear, G, ss, parameters, tol=1e-6, maxit=20, T_cut=50, initial_distribution=None):
    """Solve for the nonlinear price-level path induced by a real marginal-cost path.

    Iterates on a guess ``p`` for the price level. The nominal gap path fed to
    ``mc.mit_shock`` is ``p + mc_path``. The implied price path is compared with the guess
    on all but the last ``T_cut`` periods, and the guess is corrected with the linear step
    ``p -= G @ (p - p_new)`` on that same window.

    Parameters
    ----------
    mc_path : array
        Real marginal-cost path (same length as ``p_linear``).
    p_linear : array
        Initial guess for the price path; copied, not modified.
    G : 2-D array
        Matrix used in the correction step; only its leading block is used.
    ss : dict
        Steady state; ``len(ss['x_grid'])`` is the grid size passed to ``mc.mit_shock``.
    parameters : dict
        Model parameters forwarded to ``mc.mit_shock``.
    tol : float
        Convergence tolerance on the max absolute change over the checked window.
    maxit : int
        Maximum number of iterations; a RuntimeWarning is issued if it is exhausted.
    T_cut : int
        Number of trailing periods excluded from the convergence check and the update.
    initial_distribution : array or None
        Optional starting distribution forwarded to ``mc.mit_shock``. When given, the
        returned path is shifted down by ``sum(initial_distribution * ss['x_grid'])``
        (presumably the mean price gap of that distribution).

    Returns
    -------
    p_nonlinear : array
        Final price-path guess.
    g_transition
        Fifth output of ``mc.mit_shock(..., return_only_irf=False)`` for the final guess.
    """
    import warnings

    nx = len(ss['x_grid'])
    p_nonlinear = p_linear.copy()

    for _ in range(maxit):
        MC_path = p_nonlinear + mc_path

        irf = mc.mit_shock(parameters, nx, shock_paths={'gap': MC_path}, output_list=['p'], ss=ss, return_only_irf=True, initial_distribution=initial_distribution)
        p_nonlinear_new = irf['p']

        dist = np.max(np.abs(p_nonlinear[:-T_cut] - p_nonlinear_new[:-T_cut]))
        if dist < tol:
            break
        p_nonlinear[:-T_cut] = p_nonlinear[:-T_cut] - G[:-T_cut, :-T_cut] @ (p_nonlinear[:-T_cut] - p_nonlinear_new[:-T_cut])
        # NOTE(review): the window above is [:-T_cut] (all but the last T_cut periods), but
        # this line overwrites everything from index T_cut onward with p_nonlinear[T_cut].
        # If the intent is to hold only the last T_cut periods constant it would read
        # p_nonlinear[-T_cut:] = p_nonlinear[-T_cut - 1]. Confirm before changing.
        p_nonlinear[T_cut:] = p_nonlinear[T_cut]
    else:
        warnings.warn(f'real_mc_shock: no convergence after {maxit} iterations (tol={tol})', RuntimeWarning)

    # Rebuild the shock path from the final guess. After a converged break this equals the
    # last MC_path used above. If maxit was exhausted, the guess was updated after that path
    # was formed, so the transition would otherwise belong to a stale guess. This also
    # defines MC_path when maxit == 0.
    MC_path = p_nonlinear + mc_path
    _, _, _, _, g_transition = mc.mit_shock(parameters, nx, shock_paths={'gap': MC_path}, output_list=['p'], ss=ss, return_only_irf=False, initial_distribution=initial_distribution)

    if initial_distribution is not None:
        p_nonlinear -= np.sum(initial_distribution * ss['x_grid'])

    return p_nonlinear, g_transition


model_names = ['GL', 'NS']
share_free_values = [0, 0.75]

targs = {'freq': freq_data, 'med_abs_dp': med_abs_dp_data}  # calibration targets

xlim = 13  # number of quarters shown in each panel
# First-difference operator on the IRF horizon (price level -> inflation), built once
# instead of before every figure.
D = np.diag(np.ones(T_irf)) - np.diag(np.ones(T_irf - 1), -1)


def _plot_d8_panel(p_calvo, p_menu_linear, p_menu_nonlinear, fig_name, show_legend=False):
    """Plot annualized inflation implied by three price-level paths and save it as ``fig_name``."""
    fig, ax = plt.subplots(figsize=specs.figsize_standard)
    ax.plot(n_periods * (D @ p_calvo)[:xlim], linestyle='-', color='firebrick')
    ax.plot(n_periods * (D @ p_menu_linear)[:xlim], linestyle='--', color='royalblue')
    ax.plot(n_periods * (D @ p_menu_nonlinear)[:xlim], linestyle='-.', color='darkblue')
    ax.set_ylabel('Annualized inflation')
    ax.set_xlabel('Quarters')
    ax.set_xticks([0, 2, 4, 6, 8, 10, 12])
    if show_legend:
        ax.legend(['Linear Calvo', 'Linear menu cost', 'Nonlinear menu cost'])
    specs.save_figure(fig, fig_name)


for model, share_free in zip(model_names, share_free_values):
    if model == 'GL':
        fig_letters = ['a', 'c', 'e']
    else:
        fig_letters = ['b', 'd', 'f']

    parameters = parameters_base.copy()
    parameters.update({'la': share_free * freq_data})
    if share_free == 0:
        initial_guess = {'mu_k': -5.122636000706496, 'sig': 0.045874030709693224}
    else:
        initial_guess = {'mu_k': -2.9707928427194865, 'sig': 0.060354917627457566}

    parameters, ss = mc.calibrate_model(targs, parameters, initial_guess, nx, x_grid=x_grid_large)

    J_nom, _, _, _, _, _, _, _ = mc.compute_td_equivalence(parameters, nx, T_irf, ss=ss, h=h)
    J_real_large, J_nom_large = mc.compute_PC(J_nom, parameters, ss=ss, h=h, T_extend=T_extend,
                                              permanent_shock_to_normalize=True, cut_back=False, return_nominal=True)
    J_real = J_real_large[:T_irf, :T_irf]

    J_nom_approx, _, _, _, _ = mc.approx_jacobian(J_nom, beta, Price=True, Nominal=True, Absolute=True)
    J_real_approx, _, _, _, dist_real = mc.approx_jacobian(J_real, beta, Price=False, Nominal=False, Absolute=True)

    # G = (D J_nom)^{-1} J_real on the extended horizon, truncated to T_irf; passed to
    # real_mc_shock for its update step.
    D_large = np.diag(np.ones(len(J_real_large))) - np.diag(np.ones(len(J_real_large) - 1), -1)
    G = np.linalg.solve(D_large @ J_nom_large, J_real_large)[:T_irf, :T_irf]

    # Linear Calvo price path for the unit mc_path. Each shock below rescales it (and
    # mc_path) so the impact price response equals the target size. Scaled copies are used
    # so the module-level mc_path is never modified in place.
    p_linear_unit = np.cumsum(J_real_approx @ mc_path)

    # small shock: from the steady state (panel 0) and from the initial distribution
    # g_transition[4] of that first run (panel 2)
    scale = shock_size[0] / p_linear_unit[0]
    mc_scaled = mc_path * scale
    p_linear = p_linear_unit * scale
    p_sd_linear = np.cumsum(J_real @ mc_scaled)

    p_nonlinear, g_transition = real_mc_shock(mc_scaled, p_linear, G, ss, parameters)
    p_nonlinear2, _ = real_mc_shock(mc_scaled, p_linear, G, ss, parameters, initial_distribution=g_transition[4])

    _plot_d8_panel(p_linear, p_sd_linear, p_nonlinear, 'figure_D8_' + fig_letters[0], show_legend=(model == 'NS'))
    _plot_d8_panel(p_linear, p_sd_linear, p_nonlinear2, 'figure_D8_' + fig_letters[2])

    # large shock, from the steady state (panel 1)
    scale = shock_size[1] / p_linear_unit[0]
    mc_scaled = mc_path * scale
    p_linear = p_linear_unit * scale
    p_sd_linear = np.cumsum(J_real @ mc_scaled)

    p_nonlinear, g_transition = real_mc_shock(mc_scaled, p_linear, G, ss, parameters)

    _plot_d8_panel(p_linear, p_sd_linear, p_nonlinear, 'figure_D8_' + fig_letters[1])


# %% Part 14: figure E3

beta_here = 0.95
theta_here = 0.75
eps_here = 4

T_here = 1500
T_cut = 300

n_points = 21
mu_values_here = np.linspace(0, 0.05 / n_periods, n_points)

J_list = list()
M_list = list()
J_approx_list = list()
M_approx_list = list()
theta_approx_list = list()
beta_approx_list = list()
dist_approx_list = list()


for k, mu_here in enumerate(mu_values_here):
    # Calvo model with trend inflation mu_here, then a no-trend Calvo fit (theta and beta
    # both free) to its truncated inflation Jacobian
    M_here, J_here, _, _ = td.calvo_trend_inflation(theta_here, mu_here, 0.0, beta_here, eps_here, T_here)
    J_here = J_here[:T_cut, :T_cut]
    J_approx_here, theta_approx, beta_approx, _, _, dist_approx = mc.approx_jacobian(J_here, beta=None, Price=False, Nominal=False)

    M_approx_here = td.calvo_jacobian(theta_approx, beta_approx, T_here)

    J_list.append(J_here)
    M_list.append(M_here)
    J_approx_list.append(J_approx_here)
    M_approx_list.append(M_approx_here)
    theta_approx_list.append(theta_approx)
    beta_approx_list.append(beta_approx)
    dist_approx_list.append(dist_approx)

theta_approx = np.array(theta_approx_list)
beta_approx = np.array(beta_approx_list)
kappa_approx = (1 - theta_approx) * (1 - theta_approx * beta_approx) / theta_approx

# panels a-b are drawn at the grid point closest to 2% annual trend inflation
idx = np.argmin(np.abs(mu_values_here - 0.02 / n_periods))
fig, ax = plt.subplots(figsize=specs.figsize_standard)
plot_jacobian(M_list[idx], ax, color=colors_equivalence['menu_cost'])
plot_jacobian(M_approx_list[idx], ax, color=colors_equivalence['td'], linestyle='--')
ax.set_xlim([0, 40])
ax.set_xlabel('Quarters')
ax.set_ylabel('Price level')
custom_lines = [Line2D([0], [0], color=colors_equivalence['menu_cost'], linestyle='-'),
                Line2D([0], [0], color=colors_equivalence['td'], linestyle='--')]
ax.legend(custom_lines, ['With trend inflation', 'No trend approx.'])
specs.save_figure(fig, 'figure_E3_a', folder=folder)

fig, ax = plt.subplots(figsize=specs.figsize_standard)
plot_jacobian(J_list[idx], ax, color=colors_equivalence['menu_cost'])
plot_jacobian(J_approx_list[idx], ax, color=colors_equivalence['td'], linestyle='--')
ax.set_xlim([0, 40])
ax.set_xlabel('Quarters')
ax.set_ylabel('Inflation')
specs.save_figure(fig, 'figure_E3_b', folder=folder)

# Raw strings: "\%" is an invalid Python escape (SyntaxWarning on 3.12+); the raw literal
# has the same value and keeps the backslash for LaTeX rendering.
fig, ax = plt.subplots(figsize=specs.figsize_standard)
ax.plot(100 * n_periods * mu_values_here, beta_approx, color=colors_equivalence['menu_cost'])
ax.set_xlabel(r'Inflation (\% p.a.)')
ax.set_ylabel('Discount factor')
specs.save_figure(fig, 'figure_E3_c', folder=folder)

fig, ax = plt.subplots(figsize=specs.figsize_standard)
ax.plot(100 * n_periods * mu_values_here, kappa_approx, color=colors_equivalence['menu_cost'])
ax.set_xlabel(r'Inflation (\% p.a.)')
ax.set_ylabel('Phillips curve slope')
ax.set_yticks([0.06, 0.07, 0.08, 0.09, 0.1])
specs.save_figure(fig, 'figure_E3_d', folder=folder)


# %% Part 15: figure E4

beta_here = 0.95
mu_here = 0.02 / n_periods  # 2% annual trend inflation, expressed per period
eps_here = 4

model_names = ['GL', 'NS']
share_free_values = [0, 0.75]

# Starting values for the (mu_k, sig) calibration of each model.
initial_guesses = {
    'GL': {'mu_k': -5.122636000706496, 'sig': 0.045874030709693224},
    'NS': {'mu_k': -2.9707928427194865, 'sig': 0.060354917627457566},
}
fig_letters = {'GL': 'a', 'NS': 'b'}

custom_lines = [Line2D([0], [0], color=colors_equivalence['menu_cost'], linestyle='-'),
                Line2D([0], [0], color=colors_equivalence['td'], linestyle='--'),
                Line2D([0], [0], color='green', linestyle='-.')]
legend = ['Menu cost with trend', 'Calvo with trend', 'Calvo, no trend']

targs = {'freq': freq_data, 'med_abs_dp': med_abs_dp_data}  # calibration targets


def _gap_inflation_jacobian(parameters):
    """Return the T x T Jacobian plotted as 'Inflation', and the steady state used.

    Steps:
      1. Compute the Jacobian of output 'p' with respect to the 'gap' shock.
      2. Rescale each row t so that it sums to the response to a permanent gap
         shock at t (from ``mc.permanent_gap_shock``).
      3. Extend to T_extend periods and form (I - J_nom)^{-1} J_nom.
      4. First-difference along the row (time) axis and truncate to T x T.
    """
    J_nom, ss = mc.compute_jacobian(parameters, nx, T, h, shock_list=['gap'], output_list=['p'])
    J_nom = J_nom['p']['gap']

    permanent_shock = mc.permanent_gap_shock(parameters, nx, T, h, ss['x_grid'], ss['Pi'], ss['g'])
    J_nom = J_nom * permanent_shock[:, np.newaxis] / np.sum(J_nom, axis=1)[:, np.newaxis]

    J_nom = mc.extend_matrix(J_nom, T_extend)
    J_out = np.linalg.solve(np.eye(T_extend) - J_nom, J_nom)
    J_out[1:, :] -= J_out[:-1, :]  # numpy copies overlapping operands, so this is a true difference
    return J_out[:T, :T], ss


for model, share_free in zip(model_names, share_free_values):
    parameters = parameters_base.copy()
    parameters.update({'la': share_free * freq_data, 'r': 1 / beta_here - 1, 'elast': eps_here})
    initial_guess = dict(initial_guesses[model])  # copy so the table is never mutated
    parameters, ss = mc.calibrate_model(targs, parameters, initial_guess, nx)

    # Zero-trend model, its Calvo approximation, and that Calvo model under trend inflation.
    J, ss = _gap_inflation_jacobian(parameters)
    J_approx, theta, beta_calvo, *_ = mc.approx_jacobian(J, beta=None, Price=False, Nominal=False, Absolute=True)
    # Use eps_here (not a literal) so the Calvo curve stays consistent with the calibrated 'elast'.
    _, J_approx_trend, *_ = td.calvo_trend_inflation(theta, mu_here, 0, beta_calvo, eps=eps_here, T=T)

    # Same calibrated model, re-solved with trend inflation mu_here.
    parameters.update({'mu': mu_here})
    J_trend, ss = _gap_inflation_jacobian(parameters)

    fig, ax = plt.subplots(figsize=specs.figsize_standard)
    plot_jacobian(J_trend, ax, color=colors_equivalence['menu_cost'], linestyle='-')
    plot_jacobian(J_approx_trend, ax, color=colors_equivalence['td'], linestyle='--')
    plot_jacobian(J_approx, ax, color='green', linestyle='-.')
    ax.set_xlim([0, 40])
    ax.set_xlabel('Quarters')
    ax.set_ylabel('Inflation')
    if model == 'NS':
        ax.legend(custom_lines, legend)
    # Save into the same output folder as the other figures.
    specs.save_figure(fig, 'figure_E4_' + fig_letters[model], folder=folder)


# %% Part 16: figure C3

share_free_values = [0, 0.5, 0.75, 0.9, 0.99, 0.995]

haz_e = []  # hazards derived from f_ext, one array per share_free value
haz_i = []  # hazards derived from f_int, one array per share_free value
alpha = []  # w_ext returned by compute_td_equivalence, used in the legend

targs = {'freq': freq_data}

for share_free in share_free_values:
    parameters = parameters_base.copy()
    parameters.update({'la': share_free * freq_data})
    initial_guess = {'mu_k': -3.0}
    parameters, ss = mc.calibrate_model(targs, parameters, initial_guess, nx)

    _, _, _, _, f_ext, f_int, w_ext, w_int = mc.compute_td_equivalence(parameters, nx, T, ss=ss, h=h)

    # Hazard at t is 1 - f[t+1] / f[t] (treating f as a survival curve -- assumed).
    # The 1e-8 floor avoids dividing by zero once f has decayed to 0.
    hazards_ext = 1 - f_ext[1:] / np.maximum(f_ext[:-1], 1e-8)
    hazards_int = 1 - f_int[1:] / np.maximum(f_int[:-1], 1e-8)

    haz_e.append(hazards_ext)
    haz_i.append(hazards_int)
    alpha.append(w_ext)


color_list_here = ['firebrick', 'black', 'purple', 'tab:blue', 'green', 'grey']
ls_list_here = ['-', '--', '-.', ':', (0, (1, 5)), (5, (10, 3))]

tmax = 10  # number of periods shown (inclusive of t = tmax)

# (b) Hazards built from f_int.
fig, ax = plt.subplots()
for k, hazards in enumerate(haz_i):
    ax.plot(hazards[:tmax+1], color=color_list_here[k], ls=ls_list_here[k])
ax.set_xlabel('Quarters')
ax.set_ylabel('Adjustment hazards')
specs.save_figure(fig, filename='figure_C3_b', folder=folder)

# (a) Hazards built from f_ext, labelled by share of free adjustments and alpha.
# Raw f-strings keep the LaTeX "\%" without an invalid escape sequence.
legend = []
fig, ax = plt.subplots()
for k, (s, hazards) in enumerate(zip(share_free_values, haz_e)):
    ax.plot(hazards[:tmax+1], color=color_list_here[k], ls=ls_list_here[k])
    # Only the first entry spells out "free adj."; later ones abbreviate.
    share_label = rf'{100*s:.1f}\% free adj., ' if k == 0 else rf'{100*s:.1f}\%, '
    legend.append(share_label + r'$\alpha = $' + f'{alpha[k]:.2f}')
ax.legend(legend, loc=1)
ax.set_xlabel('Quarters')
ax.set_ylabel('Adjustment hazards')
specs.save_figure(fig, filename='figure_C3_a', folder=folder)
